From eb2947af7babd59d9e5b0a35bca29d1b77ba6df1 Mon Sep 17 00:00:00 2001 From: Wangzheee <634486483@qq.com> Date: Wed, 30 Nov 2022 12:10:37 +0000 Subject: [PATCH 01/19] general optimization no_varlen embedding layernorm --- .../tensorrt/convert/emb_eltwise_layernorm.cc | 205 +++----- .../convert/preln_emb_eltwise_layernorm.cc | 6 +- .../inference/tensorrt/plugin/CMakeLists.txt | 5 +- .../plugin/emb_eltwise_layernorm_plugin.cu | 291 ---------- .../plugin/emb_eltwise_layernorm_plugin.h | 446 ---------------- .../plugin/many_emb_Layernorm_kernel.cu | 475 +++++++++++++++++ ...any_emb_Layernorm_varseqlen_kernelHFace.cu | 6 - ...any_emb_Layernorm_varseqlen_kernelMTron.cu | 6 - .../plugin/many_emb_layernorm_plugin.cu | 497 ++++++++++++++++++ .../plugin/many_emb_layernorm_plugin.h | 203 +++++++ .../many_emb_layernorm_varseqlen_plugin.cu | 90 ++-- .../many_emb_layernorm_varseqlen_plugin.h | 1 - 12 files changed, 1322 insertions(+), 909 deletions(-) delete mode 100644 paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.cu delete mode 100644 paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.h create mode 100644 paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_kernel.cu create mode 100644 paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_plugin.cu create mode 100644 paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_plugin.h diff --git a/paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc b/paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc index acf6bafe06b4e..62d98591f62ec 100644 --- a/paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc +++ b/paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc @@ -13,7 +13,7 @@ limitations under the License. */ #include "paddle/fluid/inference/tensorrt/convert/utils.h" #include "paddle/fluid/inference/tensorrt/engine.h" #include "paddle/fluid/inference/tensorrt/helper.h" -#include "paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.h" +#include "paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_plugin.h" #include "paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.h" #include "paddle/phi/core/ddim.h" @@ -36,7 +36,6 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter { const framework::Scope& scope, bool test_mode) override { VLOG(4) << "convert fluid EmbEltwiseLayerNorm op to tensorrt layer"; - // get the presistable var's data auto GetWeight = [&](const std::string& var_name, framework::DDim* dim) -> TensorRTEngine::Weight { @@ -47,32 +46,13 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter { return weight; }; - auto GetFp16Weight = [&](const std::string& var_name, - framework::DDim* dim) -> TensorRTEngine::Weight { - auto* temp_var = scope.FindVar(var_name); - auto* temp_tensor = temp_var->GetMutable(); - *dim = temp_tensor->dims(); - auto weight = engine_->GetFp16TrtWeight(var_name, *temp_tensor); - return weight; - }; - - auto GetFp32Weight = [&](const std::string& var_name, - framework::DDim* dim) -> TensorRTEngine::Weight { - auto* temp_var = scope.FindVar(var_name); - auto* temp_tensor = temp_var->GetMutable(); - *dim = temp_tensor->dims(); - auto weight = engine_->GetFp32TrtWeight(var_name, *temp_tensor); - return weight; - }; - framework::OpDesc op_desc(op, nullptr); auto pos_id_name = engine_->tensorrt_transformer_posid(); auto mask_id_name = engine_->tensorrt_transformer_maskid(); bool flag_varseqlen = engine_->use_varseqlen() && pos_id_name != "" && mask_id_name != ""; - bool 
with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16(); - int hidden = 0; - // Declare inputs + // bool with_fp16 = engine_->WithFp16() && + // !engine_->disable_trt_plugin_fp16(); + // int hidden = 0; + // Declare inputs std::vector<nvinfer1::ITensor*> input_ids; // Declare inputs_weight @@ -95,55 +75,6 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter { if (flag_varseqlen) { engine_->SetITensor("pos_id", engine_->GetITensor(pos_id_name)); engine_->SetITensor("mask_id", engine_->GetITensor(mask_id_name)); - - auto mask_id_tensor = engine_->GetITensor("mask_id"); - auto mask_dims = mask_id_tensor->getDimensions(); - auto slice_start_dims = mask_dims; - auto slice_stride_dims = mask_dims; - - for (int i = 0; i < mask_dims.nbDims; i++) { - slice_start_dims.d[i] = 0; - slice_stride_dims.d[i] = 1; - } - - auto* shape_tensor = Shape(mask_id_tensor); - std::vector<nvinfer1::ITensor*> size_vec_tensor; - std::vector<nvinfer1::ITensor*> start_vec_tensor; - for (int i = 0; i < mask_dims.nbDims; i++) { - size_vec_tensor.push_back(Add1DConstantLayer(1)); - start_vec_tensor.push_back(Add1DConstantLayer(0)); - } - size_vec_tensor[1] = GetEleTensorOfShape(shape_tensor, 1); - auto size_tensor = Concat(size_vec_tensor); - auto start_tensor = Concat(start_vec_tensor); - - auto slice_layer = - TRT_ENGINE_ADD_LAYER(engine_, - Slice, - *mask_id_tensor, - slice_start_dims, - slice_start_dims, - slice_stride_dims); // unuseful slice_start_dims - slice_layer->setInput(1, *start_tensor); - slice_layer->setInput(2, *size_tensor); - slice_layer->setName( - ("Embeltwise_slice_layer (Output: slice_max_seqlen " + - op_desc.Output("Out")[0] + ")") - .c_str()); - engine_->SetTensorDynamicRange(slice_layer->getOutput(0), 1.0f); - - auto* reshape_layer = - TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *slice_layer->getOutput(0)); - nvinfer1::Dims shape_dim; - shape_dim.nbDims = 1; - shape_dim.d[0] = -1; - reshape_layer->setReshapeDimensions(shape_dim); - reshape_layer->setName(("Embeltwise_reshape_layer (Output: max_seqlen " + - op_desc.Output("Out")[0] + ")") - .c_str()); - engine_->SetTensorDynamicRange(reshape_layer->getOutput(0), 1.0f); - engine_->SetITensor("max_seqlen_tensor", reshape_layer->getOutput(0)); - for (int i = 0; i < input_num; i++) { auto input_tensor = engine_->GetITensor(id_names[i]); weight = GetWeight(emb_names[i], &emb_dims); @@ -156,7 +87,7 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter { input_embs.push_back(weight.get()); emb_sizes.push_back(weight.get().count); } - hidden = emb_dims[1]; + // hidden = emb_dims[1]; } bias_weight = GetWeight(op_desc.Input("Bias").front(), &bias_dims); scale_weight = GetWeight(op_desc.Input("Scale").front(), &scale_dims); @@ -206,26 +137,29 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter { plugin_ptr->fields = fields.data(); std::vector<nvinfer1::ITensor*> plugin_inputs = input_ids; - plugin_inputs.emplace_back(engine_->GetITensor( - "max_seqlen_tensor")); // max_seqlen, eval_placeholder_3 - + plugin_inputs.emplace_back( + engine_->GetITensor("mask_id")); // input mask_id auto creator = GetPluginRegistry()->getPluginCreator( - "ManyEmbLayerNormPluginDynamic", "1"); - auto plugin_obj = - creator->createPlugin("ManyEmbLayerNormPluginDynamic", plugin_ptr); + "ManyEmbLayerNormVarlenPluginDynamic", "1"); + auto plugin_obj = creator->createPlugin( + "ManyEmbLayerNormVarlenPluginDynamic", plugin_ptr); auto plugin_layer = engine_->network()->addPluginV2( plugin_inputs.data(), plugin_inputs.size(), *plugin_obj); - plugin_layer->setName(("ManyEmbLayerNormPluginDynamic_V1(Output: " + +
plugin_layer->setName(("ManyEmbLayerNormVarlenPluginDynamicV1(Output: " + op_desc.Output("Out")[0] + ")") .c_str()); free(plugin_ptr); if (enable_int8) { float out_scale = PADDLE_GET_CONST(float, op_desc.GetAttr("out_threshold")); - engine_->SetTensorDynamicRange(plugin_layer->getOutput(0), out_scale); - engine_->SetTensorDynamicRange(plugin_layer->getOutput(1), out_scale); + engine_->SetTensorDynamicRange(plugin_layer->getOutput(0), + out_scale); // output + engine_->SetTensorDynamicRange(plugin_layer->getOutput(1), + out_scale); // mask + engine_->SetTensorDynamicRange(plugin_layer->getOutput(2), + out_scale); // max seqlen } if (engine_->with_interleaved()) { VLOG(4) << "fused emb_eltwise_layernorm op: use_varseqlen and " @@ -249,54 +183,89 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter { auto output_name = op_desc.Output("Out")[0]; RreplenishLayerAndOutput(layer, "ManyEmbLayerNormPluginDynamic_V1", - {output_name, std::string("qkv_plugin_mask")}, + {output_name, + std::string("qkv_plugin_mask"), + std::string("max_seqlen_tensor")}, test_mode); } } else { for (int i = 0; i < input_num; i++) { - if (with_fp16) { - weight = GetFp16Weight(emb_names[i], &emb_dims); - } else { - weight = GetFp32Weight(emb_names[i], &emb_dims); - } - input_ids.push_back(engine_->GetITensor(id_names[i])); + auto input_tensor = engine_->GetITensor(id_names[i]); + weight = GetWeight(emb_names[i], &emb_dims); + input_ids.push_back(input_tensor); input_embs.push_back(weight.get()); emb_sizes.push_back(weight.get().count); - hidden = emb_dims[1]; - } - if (with_fp16) { - bias_weight = GetFp16Weight(op_desc.Input("Bias").front(), &bias_dims); - scale_weight = - GetFp16Weight(op_desc.Input("Scale").front(), &scale_dims); - } else { - bias_weight = GetFp32Weight(op_desc.Input("Bias").front(), &bias_dims); - scale_weight = - GetFp32Weight(op_desc.Input("Scale").front(), &scale_dims); + // hidden = emb_dims[1]; } + bias_weight = GetWeight(op_desc.Input("Bias").front(), &bias_dims); + scale_weight = GetWeight(op_desc.Input("Scale").front(), &scale_dims); bias_size = phi::product(bias_dims); scale_size = phi::product(scale_dims); - float eps = PADDLE_GET_CONST(float, op_desc.GetAttr("epsilon")); - plugin::DynamicPluginTensorRT* plugin = nullptr; - std::vector<void*> input_embs_data; - for (size_t i = 0; i < input_embs.size(); ++i) { - input_embs_data.push_back(const_cast<void*>( - reinterpret_cast<const void*>(input_embs[i].values))); + + int output_fp16 = static_cast<int>((engine_->WithFp16() == 1) ? 1 : 0); + if (enable_int8) { + output_fp16 = 1; + } + PADDLE_ENFORCE_EQ( + output_fp16, + 1, + platform::errors::InvalidArgument( + "Only Precision::KHalf(fp16) is supported when inferring " + "ernie(bert) model with config.EnableVarseqlen(). " "But Precision::KFloat32 is set.")); + + std::vector<nvinfer1::PluginField> fields; + std::vector<std::string> temp_fields_keys; + fields.emplace_back("bert_embeddings_layernorm_beta", + bias_weight.get().values, + GetPluginFieldType(bias_weight.get().type), + static_cast<int32_t>(bias_size)); + fields.emplace_back("bert_embeddings_layernorm_gamma", + scale_weight.get().values, + GetPluginFieldType(scale_weight.get().type), + static_cast<int32_t>(scale_size)); + fields.emplace_back( + "output_fp16", &output_fp16, nvinfer1::PluginFieldType::kINT32, 1); + for (int i = 0; i < input_num; ++i) { + temp_fields_keys.push_back("bert_embeddings_word_embeddings_" + + std::to_string(i)); + fields.emplace_back(temp_fields_keys.rbegin()->c_str(), + input_embs[i].values, + GetPluginFieldType(input_embs[i].type), + static_cast<int32_t>(emb_sizes[i])); + } + + nvinfer1::PluginFieldCollection* plugin_ptr = + static_cast<nvinfer1::PluginFieldCollection*>( + malloc(sizeof(*plugin_ptr) + + fields.size() * sizeof(nvinfer1::PluginField))); + plugin_ptr->nbFields = static_cast<int>(fields.size()); + plugin_ptr->fields = fields.data(); + + std::vector<nvinfer1::ITensor*> plugin_inputs = input_ids; + + auto creator = GetPluginRegistry()->getPluginCreator( + "ManyEmbLayerNormPluginDynamic", "1"); + auto plugin_obj = + creator->createPlugin("ManyEmbLayerNormPluginDynamic", plugin_ptr); + + auto plugin_layer = engine_->network()->addPluginV2( + plugin_inputs.data(), plugin_inputs.size(), *plugin_obj); + + plugin_layer->setName(("ManyEmbLayerNormPluginDynamicV1(Output: " + + op_desc.Output("Out")[0] + ")") + .c_str()); + free(plugin_ptr); + if (enable_int8) { + float out_scale = + PADDLE_GET_CONST(float, op_desc.GetAttr("out_threshold")); + engine_->SetTensorDynamicRange(plugin_layer->getOutput(0), + out_scale); // output } - plugin = new plugin::EmbEltwiseLayernormPluginDynamic( - input_embs_data, - const_cast<void*>(static_cast<const void*>(bias_weight.get().values)), - const_cast<void*>( - static_cast<const void*>(scale_weight.get().values)), - emb_sizes, - bias_size, - scale_size, - hidden, - eps, - with_fp16); - layer = engine_->AddDynamicPlugin(input_ids.data(), input_num, plugin); + layer = plugin_layer; auto output_name = op_desc.Output("Out")[0]; RreplenishLayerAndOutput( - layer, "emb_eltwise_layernorm", {output_name}, test_mode); + layer, "ManyEmbLayerNormPluginDynamicV1", {output_name}, test_mode); } } }; diff --git a/paddle/fluid/inference/tensorrt/convert/preln_emb_eltwise_layernorm.cc b/paddle/fluid/inference/tensorrt/convert/preln_emb_eltwise_layernorm.cc index 07a97b32f702b..43c62523b8fce 100644 --- a/paddle/fluid/inference/tensorrt/convert/preln_emb_eltwise_layernorm.cc +++ b/paddle/fluid/inference/tensorrt/convert/preln_emb_eltwise_layernorm.cc @@ -194,10 +194,10 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter { "max_seqlen_tensor")); // max_seqlen, eval_placeholder_3 auto creator = GetPluginRegistry()->getPluginCreator( - "ManyEmbLayerNormPluginDynamic", "2"); + "ManyEmbLayerNormVarlenPluginDynamic", "2"); - auto plugin_obj = - creator->createPlugin("ManyEmbLayerNormPluginDynamic", plugin_ptr); + auto plugin_obj = creator->createPlugin( + "ManyEmbLayerNormVarlenPluginDynamic", plugin_ptr); auto plugin_layer = engine_->network()->addPluginV2( plugin_inputs.data(), plugin_inputs.size(), *plugin_obj); diff --git a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt index a72880780d81e..40f9ef127f5cc 100644 --- a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt @@ -11,7 +11,6 @@ list( group_norm_op_plugin.cu
layer_norm_op_plugin.cu instance_norm_op_plugin.cu - emb_eltwise_layernorm_plugin.cu qkv_to_context_plugin.cu skip_layernorm_op_plugin.cu hard_swish_op_plugin.cu @@ -37,7 +36,9 @@ list( merge_layernorm_op_plugin.cu skip_merge_layernorm_op_plugin.cu generic_plugin.cu - lookup_table.cu) + lookup_table.cu + many_emb_layernorm_plugin.cu + many_emb_Layernorm_kernel.cu) if(${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 7 AND NOT WIN32) list(APPEND TRT_FILES many_emb_layernorm_varseqlen_plugin.cu diff --git a/paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.cu deleted file mode 100644 index b4f8ae0432a32..0000000000000 --- a/paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.cu +++ /dev/null @@ -1,291 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include - -#include -#include // NOLINT -#include -#include - -#include "glog/logging.h" -#include "paddle/fluid/framework/tensor.h" -#include "paddle/fluid/framework/tensor_util.h" -#include "paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.h" -#include "paddle/fluid/operators/math/bert_encoder_functor.h" - -namespace paddle { -namespace inference { -namespace tensorrt { -namespace plugin { - -// Dynamic shape plugin requires TRT version greater than 6.0. 
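For context on what is being removed: the deleted EmbEltwiseLayernormPluginDynamicImpl below copies every embedding table to the GPU at initialize() time and then also copies the array of device pointers itself (emb_ptr_tensor_), so one kernel can walk an arbitrary number of tables through double indirection. A minimal sketch of that pattern, with a hypothetical helper name and error checking elided (production code must check every cudaError_t):

    #include <cuda_runtime.h>
    #include <vector>

    // Upload N embedding tables, then the table of device pointers itself,
    // mirroring what initialize() below does with embs_gpu_ / emb_ptr_tensor_.
    void uploadEmbeddingTables(const std::vector<const float*>& host_tables,
                               const std::vector<size_t>& table_sizes,
                               std::vector<float*>* dev_tables,
                               float*** dev_table_array) {
      dev_tables->resize(host_tables.size());
      for (size_t i = 0; i < host_tables.size(); ++i) {
        cudaMalloc(&(*dev_tables)[i], table_sizes[i] * sizeof(float));
        cudaMemcpy((*dev_tables)[i], host_tables[i],
                   table_sizes[i] * sizeof(float), cudaMemcpyHostToDevice);
      }
      // Second level of indirection: the pointer array must live on the
      // device too, because the kernel indexes it per lookup table.
      cudaMalloc(dev_table_array, host_tables.size() * sizeof(float*));
      cudaMemcpy(*dev_table_array, dev_tables->data(),
                 host_tables.size() * sizeof(float*), cudaMemcpyHostToDevice);
    }

The replacement plugin added by this patch drops that indirection entirely: it ships fixed-arity kernels for 2, 3, or 4 tables and passes each table pointer as a plain kernel argument.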
-#if IS_TRT_VERSION_GE(6000) - -template -void EmbEltwiseLayernormPluginDynamicImpl::shareGPUData( - const EmbEltwiseLayernormPluginDynamicImplBase *anthor) { - auto *ptr = - dynamic_cast *>(anthor); - if (!ptr->is_initialized_) { - return; - } - embs_gpu_ = ptr->embs_gpu_; - scale_gpu_ = ptr->scale_gpu_; - bias_gpu_ = ptr->bias_gpu_; - int input_num = embs_.size(); - in_ptr_tensor_.Resize({input_num}); - emb_ptr_tensor_.ShareDataWith(ptr->emb_ptr_tensor_); -} - -template -int EmbEltwiseLayernormPluginDynamicImpl::initialize() { - if (is_initialized_) { - return 0; - } - embs_gpu_.resize(embs_.size()); - for (int i = 0; i < embs_.size(); i++) { - if (embs_[i]) { - T *host_ptr = embs_[i]; - auto size = emb_sizes_[i]; - - cudaMalloc(&embs_gpu_[i], sizeof(T) * size); - cudaMemcpy( - embs_gpu_[i], host_ptr, size * sizeof(T), cudaMemcpyHostToDevice); - } - } - - if (bias_) { - cudaMalloc(&bias_gpu_, sizeof(T) * bias_size_); - cudaMemcpy( - bias_gpu_, bias_, bias_size_ * sizeof(T), cudaMemcpyHostToDevice); - } - if (scale_) { - cudaMalloc(&scale_gpu_, sizeof(T) * scale_size_); - cudaMemcpy( - scale_gpu_, scale_, scale_size_ * sizeof(T), cudaMemcpyHostToDevice); - } - - int input_num = embs_.size(); - in_ptr_tensor_.Resize({input_num}); - emb_ptr_tensor_.Resize({input_num}); - cudaGetDevice(&device_id_); - auto emb_ptr_gpu_d = - emb_ptr_tensor_.mutable_data(platform::CUDAPlace(device_id_)); - cudaMemcpy(emb_ptr_gpu_d, - embs_gpu_.data(), - sizeof(uintptr_t) * input_num, - cudaMemcpyHostToDevice); - is_initialized_ = true; - return 0; -} - -template -void EmbEltwiseLayernormPluginDynamicImpl::terminate() { - for (int i = 0; i < embs_gpu_.size(); ++i) { - if (embs_gpu_[i]) { - cudaFree(embs_gpu_[i]); - embs_gpu_[i] = nullptr; - } - } - - if (bias_gpu_) { - cudaFree(bias_gpu_); - bias_gpu_ = nullptr; - } - - if (scale_gpu_) { - cudaFree(scale_gpu_); - scale_gpu_ = nullptr; - } -} - -template -int EmbEltwiseLayernormPluginDynamicImpl::enqueue( - const nvinfer1::PluginTensorDesc *input_desc, - const nvinfer1::PluginTensorDesc *output_desc, - const void *const *inputs, - void *const *outputs, - void *workspace, - cudaStream_t stream) TRT_NOEXCEPT { - auto id_dims = input_desc[0].dims; - int batch = id_dims.d[0]; - int seq_len = id_dims.d[1]; - int input_num = embs_.size(); - cudaGetDevice(&device_id_); - auto in_ptr_gpu_d = - in_ptr_tensor_.mutable_data(platform::CUDAPlace(device_id_)); - auto emb_ptr_gpu_d = - emb_ptr_tensor_.mutable_data(platform::CUDAPlace(device_id_)); - - cudaMemcpyAsync(in_ptr_gpu_d, - reinterpret_cast(inputs), - sizeof(uintptr_t) * input_num, - cudaMemcpyHostToDevice, - stream); - - auto out_type = output_desc[0].type; - - if (std::is_same::value) { - PADDLE_ENFORCE_EQ( - out_type == nvinfer1::DataType::kFLOAT, - true, - platform::errors::InvalidArgument( - "The EmbEltwiseLayernorm Plugin only support fp32 input.")); - } else if (std::is_same::value) { - PADDLE_ENFORCE_EQ( - out_type == nvinfer1::DataType::kHALF, - true, - platform::errors::InvalidArgument( - "The EmbEltwiseLayernorm Plugin only support fp16 input.")); - } else { - PADDLE_THROW(platform::errors::Fatal( - "Unsupport data type, the out type of EmbEltwiseLayernorm should be " - "float or half.")); - } - - auto *output_d = reinterpret_cast(outputs[0]); - - operators::math::EmbEltwiseLayerNormFunctor emb_eltwise_layernorm_func; - emb_eltwise_layernorm_func(batch, - seq_len, - hidden_size_, - in_ptr_gpu_d, - scale_gpu_, - bias_gpu_, - emb_ptr_gpu_d, - output_d, - eps_, - input_num, - stream); - return 
cudaGetLastError() != cudaSuccess; -} - -template class EmbEltwiseLayernormPluginDynamicImpl; -#ifdef TRT_PLUGIN_FP16_AVALIABLE -template class EmbEltwiseLayernormPluginDynamicImpl; -#endif - -int EmbEltwiseLayernormPluginDynamic::initialize() TRT_NOEXCEPT { - impl_->initialize(); - - return 0; -} - -void EmbEltwiseLayernormPluginDynamic::terminate() TRT_NOEXCEPT { - impl_->terminate(); -} - -nvinfer1::DimsExprs EmbEltwiseLayernormPluginDynamic::getOutputDimensions( - int output_index, - const nvinfer1::DimsExprs *inputs, - int nb_inputs, - nvinfer1::IExprBuilder &expr_builder) TRT_NOEXCEPT { // NOLINT - PADDLE_ENFORCE_EQ(output_index, - 0, - platform::errors::InvalidArgument( - "There is only one output of the EmbEltwiseLayernorm, " - "so the index should be zero," - "but it's (%d)", - output_index)); - nvinfer1::DimsExprs ret; - ret.nbDims = 3; - ret.d[0] = inputs[0].d[0]; - ret.d[1] = inputs[0].d[1]; - ret.d[2] = expr_builder.constant(hidden_size_); - return ret; -} - -bool EmbEltwiseLayernormPluginDynamic::supportsFormatCombination( - int pos, - const nvinfer1::PluginTensorDesc *in_out, - int nb_inputs, - int nb_outputs) TRT_NOEXCEPT { - PADDLE_ENFORCE_NOT_NULL( - in_out, - platform::errors::InvalidArgument( - "The input of swish plugin shoule not be nullptr.")); - PADDLE_ENFORCE_EQ(nb_outputs, - 1, - platform::errors::InvalidArgument( - "The EmbEltwiseLayerNorm's output should be one" - "but it's (%d) outputs.", - nb_outputs)); - int all_nums = nb_inputs + nb_outputs; - PADDLE_ENFORCE_LT( - pos, - all_nums, - platform::errors::InvalidArgument("The pos(%d) should be less than the " - "num(%d) of the input and the output.", - pos, - all_nums)); - const nvinfer1::PluginTensorDesc &desc = in_out[pos]; - if (desc.format != nvinfer1::TensorFormat::kLINEAR) { - return false; - } - - if (pos == 0) { - return desc.type == nvinfer1::DataType::kINT32; - } - - const nvinfer1::PluginTensorDesc &prev = in_out[pos - 1]; - if (pos < all_nums - 1) { - return desc.type == nvinfer1::DataType::kINT32 && - desc.dims.d[0] == prev.dims.d[0] && desc.dims.d[1] == prev.dims.d[1]; - } - // output - if (pos == all_nums - 1) { - if (with_fp16_ == false) { - return desc.type == nvinfer1::DataType::kFLOAT; - } else { - return desc.type == nvinfer1::DataType::kHALF; - } - } - return false; -} - -nvinfer1::DataType EmbEltwiseLayernormPluginDynamic::getOutputDataType( - int index, - const nvinfer1::DataType *input_types, - int nb_inputs) const TRT_NOEXCEPT { - PADDLE_ENFORCE_EQ( - index, - 0, - platform::errors::InvalidArgument( - "The EmbEltwiseLayernorm Plugin only has one output, so the " - "index value should be 0, but get %d.", - index)); - if (with_fp16_) - return nvinfer1::DataType::kHALF; - else - return nvinfer1::DataType::kFLOAT; -} - -int EmbEltwiseLayernormPluginDynamic::enqueue( - const nvinfer1::PluginTensorDesc *input_desc, - const nvinfer1::PluginTensorDesc *output_desc, - const void *const *inputs, - void *const *outputs, - void *workspace, - cudaStream_t stream) TRT_NOEXCEPT { - impl_->enqueue(input_desc, output_desc, inputs, outputs, workspace, stream); - return cudaGetLastError() != cudaSuccess; -} - -#endif - -} // namespace plugin -} // namespace tensorrt -} // namespace inference -} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.h b/paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.h deleted file mode 100644 index d0815798a6e47..0000000000000 --- a/paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.h +++ 
/dev/null @@ -1,446 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include -#include -#include - -#include "paddle/fluid/inference/tensorrt/engine.h" -#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" - -namespace paddle { -namespace inference { -namespace tensorrt { -namespace plugin { - -#if IS_TRT_VERSION_GE(6000) - -class EmbEltwiseLayernormPluginDynamicImplBase { - public: - EmbEltwiseLayernormPluginDynamicImplBase() {} - virtual ~EmbEltwiseLayernormPluginDynamicImplBase() {} - - virtual int initialize() = 0; - virtual void terminate() = 0; - virtual int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, - const nvinfer1::PluginTensorDesc* outputDesc, - const void* const* inputs, - void* const* outputs, - void* workspace, - cudaStream_t stream) = 0; - virtual void shareGPUData( - const EmbEltwiseLayernormPluginDynamicImplBase* anthor) = 0; -}; - -template -class EmbEltwiseLayernormPluginDynamicImpl - : public EmbEltwiseLayernormPluginDynamicImplBase { - public: - explicit EmbEltwiseLayernormPluginDynamicImpl(std::vector input_embs, - T* bias, - T* scale, - std::vector emb_sizes, - int bias_size, - int scale_size, - int hidden_size, - float eps) - : embs_(input_embs), - bias_(bias), - scale_(scale), - emb_sizes_(emb_sizes), - bias_size_(bias_size), - scale_size_(scale_size), - hidden_size_(hidden_size), - eps_(eps) {} - - ~EmbEltwiseLayernormPluginDynamicImpl() {} - - int initialize(); - void terminate(); - int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, - const nvinfer1::PluginTensorDesc* outputDesc, - const void* const* inputs, - void* const* outputs, - void* workspace, - cudaStream_t stream) TRT_NOEXCEPT; - void shareGPUData(const EmbEltwiseLayernormPluginDynamicImplBase* anthor); - - private: - std::vector embs_; - T* bias_{nullptr}; - T* scale_{nullptr}; - - // data on devices - T* bias_gpu_{nullptr}; - T* scale_gpu_{nullptr}; - std::vector embs_gpu_; - - std::vector emb_sizes_; - int bias_size_; - int scale_size_; - int hidden_size_; - float eps_; - - phi::DenseTensor in_ptr_tensor_, emb_ptr_tensor_; - int device_id_{0}; - bool is_initialized_{false}; -}; - -class EmbEltwiseLayernormPluginDynamic : public DynamicPluginTensorRT { - public: - explicit EmbEltwiseLayernormPluginDynamic(std::vector input_embs, - void* bias, - void* scale, - std::vector emb_sizes, - int bias_size, - int scale_size, - int hidden_size, - float eps, - bool with_fp16) - : embs_(input_embs), - bias_(bias), - scale_(scale), - emb_sizes_(emb_sizes), - bias_size_(bias_size), - scale_size_(scale_size), - hidden_size_(hidden_size), - eps_(eps), - own_host_buff_(false) { - with_fp16_ = with_fp16; - if (with_fp16_) { -#ifdef TRT_PLUGIN_FP16_AVALIABLE - VLOG(1) << "TRT Plugin DataType selected. 
EmbEltwiseLayerNorm-->fp16"; - instantiateImpl(); -#else - PADDLE_THROW(platform::errors::Fatal( - "The Ernie(Bert) tensorRT plugin should be " - "complied with CUDA version >= 10.0 when running with fp16. " - "Please recomplie it or try to use fp32 by set " - "config.EnableTensorRtEngine(1 << 30, 1, 5, " - "AnalysisConfig::Precision::kFloat32, false, false) ")); -#endif - } else { - VLOG(1) << "TRT Plugin DataType selected. EmbEltwiseLayerNorm-->fp32"; - instantiateImpl(); - } - } - - EmbEltwiseLayernormPluginDynamic(void const* serial_data, - size_t serial_length) - : own_host_buff_(true) { - // the first var is with_fp16, we will use it. - DeserializeValue(&serial_data, &serial_length, &with_fp16_); - DeserializeValue(&serial_data, &serial_length, &emb_sizes_); - DeserializeValue(&serial_data, &serial_length, &bias_size_); - DeserializeValue(&serial_data, &serial_length, &scale_size_); - - embs_.resize(emb_sizes_.size()); - - if (with_fp16_) { - for (size_t i = 0; i < emb_sizes_.size(); i++) { - auto size = emb_sizes_[i]; - auto ptr = new half[size]; - memcpy(ptr, serial_data, sizeof(half) * size); - embs_[i] = ptr; - reinterpret_cast(serial_data) += size * sizeof(half); - serial_length -= size * sizeof(half); - } - if (bias_size_) { - bias_ = new half[bias_size_]; - memcpy(bias_, serial_data, sizeof(half) * bias_size_); - } - reinterpret_cast(serial_data) += bias_size_ * sizeof(half); - serial_length -= bias_size_ * sizeof(half); - - if (scale_size_) { - scale_ = new half[scale_size_]; - memcpy(scale_, serial_data, sizeof(half) * scale_size_); - } - reinterpret_cast(serial_data) += scale_size_ * sizeof(half); - serial_length -= scale_size_ * sizeof(half); - } else { - for (size_t i = 0; i < emb_sizes_.size(); i++) { - auto size = emb_sizes_[i]; - auto ptr = new float[size]; - memcpy(ptr, serial_data, sizeof(float) * size); - embs_[i] = ptr; - reinterpret_cast(serial_data) += size * sizeof(float); - serial_length -= size * sizeof(float); - } - if (bias_size_) { - bias_ = new float[bias_size_]; - memcpy(bias_, serial_data, sizeof(float) * bias_size_); - } - reinterpret_cast(serial_data) += bias_size_ * sizeof(float); - serial_length -= bias_size_ * sizeof(float); - - if (scale_size_) { - scale_ = new float[scale_size_]; - memcpy(scale_, serial_data, sizeof(float) * scale_size_); - } - reinterpret_cast(serial_data) += - scale_size_ * sizeof(float); - serial_length -= scale_size_ * sizeof(float); - } - - DeserializeValue(&serial_data, &serial_length, &hidden_size_); - DeserializeValue(&serial_data, &serial_length, &eps_); - - if (with_fp16_) { -#ifdef TRT_PLUGIN_FP16_AVALIABLE - instantiateImpl(); -#else - PADDLE_THROW(platform::errors::Fatal( - "The Ernie(Bert) tensorRT plugin should be " - "complied with CUDA version >= 10.0 when running with fp16. 
" - "Please recomplie it or try to use fp32 by set " - "config.EnableTensorRtEngine(1 << 30, 1, 5, " - "AnalysisConfig::Precision::kFloat32, false, false) ")); -#endif - } else { - instantiateImpl(); - } - } - - nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override { - auto ptr = new EmbEltwiseLayernormPluginDynamic(embs_, - bias_, - scale_, - emb_sizes_, - bias_size_, - scale_size_, - hidden_size_, - eps_, - with_fp16_); - ptr->shareGPUData(this); - return ptr; - } - - const char* getPluginType() const TRT_NOEXCEPT override { - return "fused_embedding_eltwise_layernorm_plugin"; - } - int getNbOutputs() const TRT_NOEXCEPT override { return 1; } - int initialize() TRT_NOEXCEPT override; - void terminate() TRT_NOEXCEPT override; - - size_t getSerializationSize() const TRT_NOEXCEPT override { - int sum_num = 0; - sum_num += SerializedSize(with_fp16_); - sum_num += SerializedSize(emb_sizes_); - - if (with_fp16_) { - for (size_t i = 0; i < emb_sizes_.size(); i++) { - sum_num += emb_sizes_[i] * sizeof(half); - } - sum_num += (bias_size_ + scale_size_) * sizeof(half); - } else { - for (size_t i = 0; i < emb_sizes_.size(); i++) { - sum_num += emb_sizes_[i] * sizeof(float); - } - sum_num += (bias_size_ + scale_size_) * sizeof(float); - } - - sum_num += SerializedSize(bias_size_); - sum_num += SerializedSize(scale_size_); - - sum_num += SerializedSize(hidden_size_); - sum_num += SerializedSize(eps_); - - return sum_num; - } - - void serialize(void* buffer) const TRT_NOEXCEPT override { - // the first var is for with_fp16, we will use it later; - SerializeValue(&buffer, with_fp16_); - SerializeValue(&buffer, emb_sizes_); - SerializeValue(&buffer, bias_size_); - SerializeValue(&buffer, scale_size_); - if (with_fp16_) { - for (size_t i = 0; i < emb_sizes_.size(); i++) { - auto size = emb_sizes_[i]; - for (int j = 0; j < size; ++j) { - SerializeValue(&buffer, reinterpret_cast(embs_[i])[j]); - } - } - for (int i = 0; i < bias_size_; ++i) { - SerializeValue(&buffer, reinterpret_cast(bias_)[i]); - } - - for (int i = 0; i < scale_size_; ++i) { - SerializeValue(&buffer, reinterpret_cast(scale_)[i]); - } - } else { - for (size_t i = 0; i < emb_sizes_.size(); i++) { - auto size = emb_sizes_[i]; - for (int j = 0; j < size; ++j) { - SerializeValue(&buffer, reinterpret_cast(embs_[i])[j]); - } - } - for (int i = 0; i < bias_size_; ++i) { - SerializeValue(&buffer, reinterpret_cast(bias_)[i]); - } - - for (int i = 0; i < scale_size_; ++i) { - SerializeValue(&buffer, reinterpret_cast(scale_)[i]); - } - } - - SerializeValue(&buffer, hidden_size_); - SerializeValue(&buffer, eps_); - } - - nvinfer1::DimsExprs getOutputDimensions( - int output_index, - const nvinfer1::DimsExprs* inputs, - int nb_inputs, - nvinfer1::IExprBuilder& expr_builder) // NOLINT - TRT_NOEXCEPT override; - - bool supportsFormatCombination(int pos, - const nvinfer1::PluginTensorDesc* in_out, - int nb_inputs, - int nb_outputs) TRT_NOEXCEPT override; - - void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, - int nb_inputs, - const nvinfer1::DynamicPluginTensorDesc* out, - int nb_outputs) TRT_NOEXCEPT override {} - - size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, - int nb_inputs, - const nvinfer1::PluginTensorDesc* outputs, - int nb_outputs) const TRT_NOEXCEPT override { - return 0; - } - - int enqueue(const nvinfer1::PluginTensorDesc* input_desc, - const nvinfer1::PluginTensorDesc* output_desc, - const void* const* inputs, - void* const* outputs, - void* workspace, - cudaStream_t stream) TRT_NOEXCEPT 
override; - nvinfer1::DataType getOutputDataType(int index, - const nvinfer1::DataType* input_types, - int nb_inputs) const - TRT_NOEXCEPT override; - - void destroy() TRT_NOEXCEPT override { - if (own_host_buff_) { - if (with_fp16_) { - for (auto ptr : embs_) { - delete[] reinterpret_cast(ptr); - } - delete[] reinterpret_cast(bias_); - delete[] reinterpret_cast(scale_); - } else { - for (auto ptr : embs_) { - delete[] reinterpret_cast(ptr); - } - delete[] reinterpret_cast(bias_); - delete[] reinterpret_cast(scale_); - } - } - delete impl_; - delete this; - } - - private: - std::vector embs_; - void* bias_{nullptr}; - void* scale_{nullptr}; - - std::vector emb_sizes_; - int bias_size_; - int scale_size_; - int hidden_size_; - float eps_; - - bool own_host_buff_{false}; - EmbEltwiseLayernormPluginDynamicImplBase* impl_{nullptr}; - - void shareGPUData(const EmbEltwiseLayernormPluginDynamic* anthor) { - impl_->shareGPUData(anthor->impl_); - } - - template - void instantiateImpl() { - std::vector embs; - embs.resize(embs_.size()); - for (size_t i = 0; i < embs_.size(); ++i) { - embs[i] = reinterpret_cast(embs_[i]); - } - impl_ = new EmbEltwiseLayernormPluginDynamicImpl( - embs, - reinterpret_cast(bias_), - reinterpret_cast(scale_), - emb_sizes_, - bias_size_, - scale_size_, - hidden_size_, - eps_); - } -}; - -class EmbEltwiseLayernormPluginDynamicCreator - : public nvinfer1::IPluginCreator { - public: - EmbEltwiseLayernormPluginDynamicCreator() {} - const char* getPluginName() const TRT_NOEXCEPT override { - return "fused_embedding_eltwise_layernorm_plugin"; - } - - const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; } - - const nvinfer1::PluginFieldCollection* getFieldNames() TRT_NOEXCEPT override { - return &field_collection_; - } - - nvinfer1::IPluginV2* createPlugin(const char* name, - const nvinfer1::PluginFieldCollection* fc) - TRT_NOEXCEPT override { - return nullptr; - } - - nvinfer1::IPluginV2* deserializePlugin(const char* name, - const void* serial_data, - size_t serial_length) - TRT_NOEXCEPT override { - return new EmbEltwiseLayernormPluginDynamic(serial_data, serial_length); - } - - void setPluginNamespace(const char* lib_namespace) TRT_NOEXCEPT override { - plugin_namespace_ = lib_namespace; - } - - const char* getPluginNamespace() const TRT_NOEXCEPT override { - return plugin_namespace_.c_str(); - } - - private: - std::string plugin_namespace_; - std::string plugin_name_; - nvinfer1::PluginFieldCollection field_collection_; - std::vector plugin_attributes_; -}; - -REGISTER_TRT_PLUGIN_V2(EmbEltwiseLayernormPluginDynamicCreator); - -#endif -} // namespace plugin -} // namespace tensorrt -} // namespace inference -} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_kernel.cu b/paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_kernel.cu new file mode 100644 index 0000000000000..46ec9ef7e7509 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_kernel.cu @@ -0,0 +1,475 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & +// AFFILIATES. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
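The kernel file that begins here computes, for every (batch, seq) token, the element-wise sum of 2 to 4 embedding rows followed by layer normalization, using one thread block per token and a block-wide reduction. As a reading aid, here is a self-contained reference of the two-table case; the fixed epsilon, the plain shared-memory tree reduction, and the omitted ID range checks are simplifications of mine, whereas the shipped kernels range-check every ID and reduce with cub through the layerNorm() helper in common.cuh:

    #include <cuda_runtime.h>

    // Reference only, not the shipped kernel. One block per token; TPB
    // threads (a power of two) stride over the hidden dimension ld.
    // Launch as, e.g.: emb2LayernormRef<256><<<numTokens, 256>>>(...).
    template <int TPB>
    __global__ void emb2LayernormRef(int ld, const int* ids0, const int* ids1,
                                     const float* emb0, const float* emb1,
                                     const float* gamma, const float* beta,
                                     float* out) {
      const int token = blockIdx.x;
      const int base = token * ld;
      __shared__ float sSum[TPB];
      __shared__ float sSq[TPB];
      float sum = 0.f, sq = 0.f;
      for (int i = threadIdx.x; i < ld; i += TPB) {
        const float v = emb0[ids0[token] * ld + i] + emb1[ids1[token] * ld + i];
        out[base + i] = v;  // keep the raw sum for the normalization pass
        sum += v;
        sq += v * v;
      }
      sSum[threadIdx.x] = sum;
      sSq[threadIdx.x] = sq;
      __syncthreads();
      // Plain tree reduction; the real code uses cub::BlockReduce.
      for (int stride = TPB / 2; stride > 0; stride >>= 1) {
        if (threadIdx.x < stride) {
          sSum[threadIdx.x] += sSum[threadIdx.x + stride];
          sSq[threadIdx.x] += sSq[threadIdx.x + stride];
        }
        __syncthreads();
      }
      const float mean = sSum[0] / ld;
      const float var = sSq[0] / ld - mean * mean;
      const float rstd = rsqrtf(var + 1e-5f);  // epsilon chosen arbitrarily
      for (int i = threadIdx.x; i < ld; i += TPB) {
        out[base + i] = (out[base + i] - mean) * rstd * gamma[i] + beta[i];
      }
    }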
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include + +#include "NvInfer.h" +#include "paddle/fluid/inference/tensorrt/plugin/common/common.cuh" +#include "paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_plugin.h" + +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { + +template <typename T, unsigned TPB> +__global__ void embLayerNormKernel_2(int32_t ld, + int32_t const* inputIds0, + int32_t const* inputIds1, + float const* beta, + float const* gamma, + T const* mIdsEmbDev0, + T const* mIdsEmbDev1, + int32_t IdsSize0, + int32_t IdsSize1, + T* output) { + T const rld = T(1.f) / T(ld); + cub::Sum pairSum; + int32_t const seqPos = blockIdx.y * gridDim.x + blockIdx.x; + extern __shared__ int32_t word_id[]; + + if (threadIdx.x == 0) { + if (static_cast<int32_t const*>(inputIds0)[seqPos] < 0 || + static_cast<int32_t const*>(inputIds0)[seqPos] >= IdsSize0) { + printf( + "Error (embLayerNormPlugin): ID is out of the lookup " + "table range: ID < 0 or ID >= vocab size "); + return; + } else { + word_id[0] = static_cast<int32_t const*>(inputIds0)[seqPos]; + } + + if (static_cast<int32_t const*>(inputIds1)[seqPos] < 0 || + static_cast<int32_t const*>(inputIds1)[seqPos] >= IdsSize1) { + printf( + "Error (embLayerNormPlugin): ID is out of the lookup " + "table range: ID < 0 or ID >= vocab size "); + return; + } else { + word_id[1] = static_cast<int32_t const*>(inputIds1)[seqPos]; + } + } + __syncthreads(); + + // offset into embeddings is given by wordId * hidden_size + int32_t const outOffset = seqPos * ld; + // the output offset is given by b * (S*hidden_size) + s * hidden_size + kvp<T> threadData(0, 0); + for (int32_t it = threadIdx.x; it < ld; it += TPB) { + int32_t const offset0 = word_id[0] * ld; + T val = mIdsEmbDev0[offset0 + it]; + int32_t const offset1 = word_id[1] * ld; + val += mIdsEmbDev1[offset1 + it]; + + output[outOffset + it] = val; + T const rldval = rld * val; + threadData = pairSum(threadData, kvp<T>(rldval, rldval * val)); + } + + // layer norm on the sum + layerNorm<T, T, float, TPB>(threadData, ld, outOffset, beta, gamma, output); +} + +template <typename T, unsigned TPB> +__global__ void embLayerNormKernel_3(int32_t ld, + int32_t const* inputIds0, + int32_t const* inputIds1, + int32_t const* inputIds2, + float const* beta, + float const* gamma, + T const* mIdsEmbDev0, + T const* mIdsEmbDev1, + T const* mIdsEmbDev2, + int32_t IdsSize0, + int32_t IdsSize1, + int32_t IdsSize2, + T* output) { + T const rld = T(1.f) / T(ld); + cub::Sum pairSum; + int32_t const seqPos = blockIdx.y * gridDim.x + blockIdx.x; + extern __shared__ int32_t word_id[]; + + if (threadIdx.x == 0) { + if (static_cast<int32_t const*>(inputIds0)[seqPos] < 0 || + static_cast<int32_t const*>(inputIds0)[seqPos] >= IdsSize0) { + printf( + "Error (embLayerNormPlugin): ID is out of the lookup " + "table range: ID < 0 or ID >= vocab size "); + return; + } else { + word_id[0] = static_cast<int32_t const*>(inputIds0)[seqPos]; + } + + if (static_cast<int32_t const*>(inputIds1)[seqPos] < 0 || + static_cast<int32_t const*>(inputIds1)[seqPos] >= IdsSize1) { + printf( + "Error (embLayerNormPlugin): ID is out of the lookup " + "table range: ID < 0 or ID >= vocab size "); + return; + } else { + word_id[1] = static_cast<int32_t const*>(inputIds1)[seqPos]; + } + + if (static_cast<int32_t const*>(inputIds2)[seqPos] < 0 || + static_cast<int32_t const*>(inputIds2)[seqPos] >= IdsSize2) { + printf( + "Error (embLayerNormPlugin): ID is out of the lookup " + "table range: ID < 0 or ID >= vocab size "); + return; + } else { + word_id[2] = static_cast<int32_t const*>(inputIds2)[seqPos]; + } + } + __syncthreads(); + + // offset into embeddings is given by wordId * hidden_size + int32_t const outOffset = seqPos * ld; + // the output offset is given by b * (S*hidden_size) + s * hidden_size + kvp<T> threadData(0, 0); + for (int32_t it = threadIdx.x; it < ld; it += TPB) { + int32_t const offset0 = word_id[0] * ld; + T val = mIdsEmbDev0[offset0 + it]; + int32_t const offset1 = word_id[1] * ld; + val += mIdsEmbDev1[offset1 + it]; + int32_t const offset2 = word_id[2] * ld; + val += mIdsEmbDev2[offset2 + it]; + + output[outOffset + it] = val; + T const rldval = rld * val; + threadData = pairSum(threadData, kvp<T>(rldval, rldval * val)); + } + + // layer norm on the sum + layerNorm<T, T, float, TPB>(threadData, ld, outOffset, beta, gamma, output); +} + +template <typename T, unsigned TPB> +__global__ void embLayerNormKernel_4(int32_t ld, + int32_t const* inputIds0, + int32_t const* inputIds1, + int32_t const* inputIds2, + int32_t const* inputIds3, + float const* beta, + float const* gamma, + T const* mIdsEmbDev0, + T const* mIdsEmbDev1, + T const* mIdsEmbDev2, + T const* mIdsEmbDev3, + int32_t IdsSize0, + int32_t IdsSize1, + int32_t IdsSize2, + int32_t IdsSize3, + T* output) { + T const rld = T(1.f) / T(ld); + cub::Sum pairSum; + int32_t const seqPos = blockIdx.y * gridDim.x + blockIdx.x; + extern __shared__ int32_t word_id[]; + + if (threadIdx.x == 0) { + if (static_cast<int32_t const*>(inputIds0)[seqPos] < 0 || + static_cast<int32_t const*>(inputIds0)[seqPos] >= IdsSize0) { + printf( + "Error (embLayerNormPlugin): ID is out of the lookup " + "table range: ID < 0 or ID >= vocab size "); + return; + } else { + word_id[0] = static_cast<int32_t const*>(inputIds0)[seqPos]; + } + + if (static_cast<int32_t const*>(inputIds1)[seqPos] < 0 || + static_cast<int32_t const*>(inputIds1)[seqPos] >= IdsSize1) { + printf( + "Error (embLayerNormPlugin): ID is out of the lookup " + "table range: ID < 0 or ID >= vocab size "); + return; + } else { + word_id[1] = static_cast<int32_t const*>(inputIds1)[seqPos]; + } + + if (static_cast<int32_t const*>(inputIds2)[seqPos] < 0 || + static_cast<int32_t const*>(inputIds2)[seqPos] >= IdsSize2) { + printf( + "Error (embLayerNormPlugin): ID is out of the lookup " + "table range: ID < 0 or ID >= vocab size "); + return; + } else { + word_id[2] = static_cast<int32_t const*>(inputIds2)[seqPos]; + } + + if (static_cast<int32_t const*>(inputIds3)[seqPos] < 0 || + static_cast<int32_t const*>(inputIds3)[seqPos] >= IdsSize3) { + printf( + "Error (embLayerNormPlugin): ID is out of the lookup " + "table range: ID < 0 or ID >= vocab size "); + return; + } else { + word_id[3] = static_cast<int32_t const*>(inputIds3)[seqPos]; + } + } + __syncthreads(); + + // offset into embeddings is given by wordId * hidden_size + int32_t const outOffset = seqPos * ld; + // the output offset is given by b * (S*hidden_size) + s * hidden_size + kvp<T> threadData(0, 0); + for (int32_t it = threadIdx.x; it < ld; it += TPB) { + int32_t const offset0 = word_id[0] * ld; + T val = mIdsEmbDev0[offset0 + it]; + int32_t const offset1 = word_id[1] * ld; + val += mIdsEmbDev1[offset1 + it]; + int32_t const offset2 = word_id[2] * ld; + val += mIdsEmbDev2[offset2 + it]; + int32_t const offset3 = word_id[3] * ld; + val += mIdsEmbDev3[offset3 + it]; + + output[outOffset + it] = val; + T const rldval = rld * val; + threadData = pairSum(threadData, kvp<T>(rldval, rldval * val)); + } + + // layer norm on the sum + layerNorm<T, T, float, TPB>(threadData, ld, outOffset, beta, gamma, output); +} + +template <typename T> +int32_t embSkipLayerNorm_2(cudaStream_t stream, + int32_t ld, + int32_t B, + int32_t S, + int const* inputIds0, + int const* inputIds1, + int32_t
nbLookupTables, + float const* beta, + float const* gamma, + T const* mIdsEmbDev0, + T const* mIdsEmbDev1, + int32_t IdsSize0, + int32_t IdsSize1, + T* output) { + constexpr int32_t tpb = 256; + dim3 const grid(S, B, 1); + dim3 const block(tpb, 1, 1); + size_t cache_size = sizeof(int32_t) * nbLookupTables; + embLayerNormKernel_2<<>>(ld, + inputIds0, + inputIds1, + beta, + gamma, + mIdsEmbDev0, + mIdsEmbDev1, + IdsSize0, + IdsSize1, + output); + return cudaPeekAtLastError(); +} + +template +int32_t embSkipLayerNorm_3(cudaStream_t stream, + int32_t ld, + int32_t B, + int32_t S, + int const* inputIds0, + int const* inputIds1, + int const* inputIds2, + int32_t nbLookupTables, + float const* beta, + float const* gamma, + T const* mIdsEmbDev0, + T const* mIdsEmbDev1, + T const* mIdsEmbDev2, + int32_t IdsSize0, + int32_t IdsSize1, + int32_t IdsSize2, + T* output) { + constexpr int32_t tpb = 256; + dim3 const grid(S, B, 1); + dim3 const block(tpb, 1, 1); + size_t cache_size = sizeof(int32_t) * nbLookupTables; + embLayerNormKernel_3<<>>(ld, + inputIds0, + inputIds1, + inputIds2, + beta, + gamma, + mIdsEmbDev0, + mIdsEmbDev1, + mIdsEmbDev2, + IdsSize0, + IdsSize1, + IdsSize2, + output); + return cudaPeekAtLastError(); +} + +template +int32_t embSkipLayerNorm_4(cudaStream_t stream, + int32_t ld, + int32_t B, + int32_t S, + int const* inputIds0, + int const* inputIds1, + int const* inputIds2, + int const* inputIds3, + int32_t nbLookupTables, + float const* beta, + float const* gamma, + T const* mIdsEmbDev0, + T const* mIdsEmbDev1, + T const* mIdsEmbDev2, + T const* mIdsEmbDev3, + int32_t IdsSize0, + int32_t IdsSize1, + int32_t IdsSize2, + int32_t IdsSize3, + T* output) { + constexpr int32_t tpb = 256; + dim3 const grid(S, B, 1); + dim3 const block(tpb, 1, 1); + size_t cache_size = sizeof(int32_t) * nbLookupTables; + embLayerNormKernel_4<<>>(ld, + inputIds0, + inputIds1, + inputIds2, + inputIds3, + beta, + gamma, + mIdsEmbDev0, + mIdsEmbDev1, + mIdsEmbDev2, + mIdsEmbDev3, + IdsSize0, + IdsSize1, + IdsSize2, + IdsSize3, + output); + return cudaPeekAtLastError(); +} + +template int32_t embSkipLayerNorm_2(cudaStream_t, + int32_t, + int32_t, + int32_t, + int32_t const*, + int32_t const*, + int32_t, + float const*, + float const*, + float const*, + float const*, + int32_t, + int32_t, + float*); + +template int32_t embSkipLayerNorm_3(cudaStream_t, + int32_t, + int32_t, + int32_t, + int32_t const*, + int32_t const*, + int32_t const*, + int32_t, + float const*, + float const*, + float const*, + float const*, + float const*, + int32_t, + int32_t, + int32_t, + float*); + +template int32_t embSkipLayerNorm_4(cudaStream_t, + int32_t, + int32_t, + int32_t, + int32_t const*, + int32_t const*, + int32_t const*, + int32_t const*, + int32_t, + float const*, + float const*, + float const*, + float const*, + float const*, + float const*, + int32_t, + int32_t, + int32_t, + int32_t, + float*); + +template int32_t embSkipLayerNorm_2(cudaStream_t, + int32_t, + int32_t, + int32_t, + int32_t const*, + int32_t const*, + int32_t, + float const*, + float const*, + half const*, + half const*, + int32_t, + int32_t, + half*); + +template int32_t embSkipLayerNorm_3(cudaStream_t, + int32_t, + int32_t, + int32_t, + int32_t const*, + int32_t const*, + int32_t const*, + int32_t, + float const*, + float const*, + half const*, + half const*, + half const*, + int32_t, + int32_t, + int32_t, + half*); + +template int32_t embSkipLayerNorm_4(cudaStream_t, + int32_t, + int32_t, + int32_t, + int32_t const*, + int32_t const*, + int32_t const*, + 
int32_t const*, + int32_t, + float const*, + float const*, + half const*, + half const*, + half const*, + half const*, + int32_t, + int32_t, + int32_t, + int32_t, + half*); +} // namespace plugin +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_varseqlen_kernelHFace.cu b/paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_varseqlen_kernelHFace.cu index 1a23755000c28..e2155025cb5f1 100644 --- a/paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_varseqlen_kernelHFace.cu +++ b/paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_varseqlen_kernelHFace.cu @@ -33,7 +33,6 @@ template __global__ void embLayerNormKernelHFace_2(int32_t ld, int32_t const* inputIds0, int32_t const* inputIds1, - int32_t nbLookupTables, float const* beta, float const* gamma, T const* mIdsEmbDev0, @@ -93,7 +92,6 @@ __global__ void embLayerNormKernelHFace_3(int32_t ld, int32_t const* inputIds0, int32_t const* inputIds1, int32_t const* inputIds2, - int32_t nbLookupTables, float const* beta, float const* gamma, T const* mIdsEmbDev0, @@ -168,7 +166,6 @@ __global__ void embLayerNormKernelHFace_4(int32_t ld, int32_t const* inputIds1, int32_t const* inputIds2, int32_t const* inputIds3, - int32_t nbLookupTables, float const* beta, float const* gamma, T const* mIdsEmbDev0, @@ -273,7 +270,6 @@ int32_t embSkipLayerNormHFace_2(cudaStream_t stream, <<>>(ld, inputIds0, inputIds1, - nbLookupTables, beta, gamma, mIdsEmbDev0, @@ -311,7 +307,6 @@ int32_t embSkipLayerNormHFace_3(cudaStream_t stream, inputIds0, inputIds1, inputIds2, - nbLookupTables, beta, gamma, mIdsEmbDev0, @@ -355,7 +350,6 @@ int32_t embSkipLayerNormHFace_4(cudaStream_t stream, inputIds1, inputIds2, inputIds3, - nbLookupTables, beta, gamma, mIdsEmbDev0, diff --git a/paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_varseqlen_kernelMTron.cu b/paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_varseqlen_kernelMTron.cu index acdf9cc5a269c..ed5b45bba3ee5 100644 --- a/paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_varseqlen_kernelMTron.cu +++ b/paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_varseqlen_kernelMTron.cu @@ -33,7 +33,6 @@ template __global__ void embLayerNormKernelMTron_2(int32_t ld, int32_t const* inputIds0, int32_t const* inputIds1, - int32_t nbLookupTables, float const* beta, float const* gamma, T const* mIdsEmbDev0, @@ -95,7 +94,6 @@ __global__ void embLayerNormKernelMTron_3(int32_t ld, int32_t const* inputIds0, int32_t const* inputIds1, int32_t const* inputIds2, - int32_t nbLookupTables, float const* beta, float const* gamma, T const* mIdsEmbDev0, @@ -172,7 +170,6 @@ __global__ void embLayerNormKernelMTron_4(int32_t ld, int32_t const* inputIds1, int32_t const* inputIds2, int32_t const* inputIds3, - int32_t nbLookupTables, float const* beta, float const* gamma, T const* mIdsEmbDev0, @@ -280,7 +277,6 @@ int32_t embSkipLayerNormMTron_2(cudaStream_t stream, <<>>(ld, inputIds0, inputIds1, - nbLookupTables, beta, gamma, mIdsEmbDev0, @@ -320,7 +316,6 @@ int32_t embSkipLayerNormMTron_3(cudaStream_t stream, inputIds0, inputIds1, inputIds2, - nbLookupTables, beta, gamma, mIdsEmbDev0, @@ -366,7 +361,6 @@ int32_t embSkipLayerNormMTron_4(cudaStream_t stream, inputIds1, inputIds2, inputIds3, - nbLookupTables, beta, gamma, mIdsEmbDev0, diff --git a/paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_plugin.cu new file mode 100644 index 
0000000000000..c5082e9f851dc --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_plugin.cu @@ -0,0 +1,497 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & +// AFFILIATES. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_plugin.h" +#include +#include +#include +#include "NvInfer.h" + +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { + +constexpr size_t threadsPerCta128 = 2 * 2 * 32; +constexpr size_t threadsPerCta256 = 1 * 4 * 32; +constexpr size_t threadsPerCta384 = 1 * 8 * 32; +// The number of xmmas in the M dimension. We use one uint32_t per XMMA in the M +// dimension: (s + 16*warps_m - 1) / (16*warps_m); +constexpr size_t xmmasM128 = 4; +constexpr size_t xmmasM256 = 16; +constexpr size_t xmmasM384 = 24; +// Packed mask size per batch. Layout is XMMAS_M * THREADS_PER_CTA. +constexpr size_t packedMaskSize128 = xmmasM128 * threadsPerCta128; +constexpr size_t packedMaskSize256 = xmmasM256 * threadsPerCta256; +constexpr size_t packedMaskSize384 = xmmasM384 * threadsPerCta384; +char const* EMB_LAYER_NORM_VERSION{"1"}; +char const* EMB_LAYER_NORM_NAME{"ManyEmbLayerNormPluginDynamic"}; +// Static class fields initialization +nvinfer1::PluginFieldCollection EmbLayerNormPluginCreator::mFC{}; +std::vector EmbLayerNormPluginCreator::mPluginAttributes; + +EmbLayerNormPlugin::EmbLayerNormPlugin( + std::string const& name, + nvinfer1::DataType const type, + nvinfer1::Weights const& beta, + nvinfer1::Weights const& gamma, + const std::vector& IdsEmb) + : mLayerName(name), + mLd(beta.count), + mType(type), + mIdsEmb_(IdsEmb), + nbLookupTables_(static_cast(IdsEmb.size())) { + // Assuming Weights.count is the number of elements and not bytes + assert(beta.count == gamma.count); + mBeta.convertAndCopy(beta, nvinfer1::DataType::kFLOAT); + mGamma.convertAndCopy(gamma, nvinfer1::DataType::kFLOAT); + copyToDevice(&mGamma, sizeof(float) * mGamma.count, &mGammaDev); + copyToDevice(&mBeta, sizeof(float) * mBeta.count, &mBetaDev); + for (size_t i = 0; i < mIdsEmb_.size(); ++i) { + assert(mIdsEmb_[i].count % mLd == 0); + mIdsVocabSize.push_back(int32_t(mIdsEmb_[i].count / mLd)); + WeightsWithOwnership tem_weight; + tem_weight.convertAndCopy(mIdsEmb_[i], mType); + void* cudaMem{nullptr}; + PADDLE_ENFORCE_GPU_SUCCESS( + cudaMalloc(&cudaMem, getWeightsSize(tem_weight, mType))); + PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpy(cudaMem, + tem_weight.values, + getWeightsSize(tem_weight, mType), + cudaMemcpyHostToDevice)); + mIdsEmbPtrs.push_back(cudaMem); + } +} + +EmbLayerNormPlugin::EmbLayerNormPlugin(std::string const& name, + void const* data, + size_t length) + : mLayerName(name), + mGammaDev(nullptr), + mBetaDev(nullptr), + mIdsEmbPtrs{}, + mIdsEmb_{} { + // Deserialize in the same order as serialization + deserialize_value(&data, &length, &mType); + 
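An aside on the layout being read here: the deserializing constructor must consume fields in exactly the order serialize() writes them further down in this file (mType, mLd, nbLookupTables_, the per-table vocab sizes, then the raw bytes of beta, gamma, and each embedding table). A minimal stand-in for the serialize_value/deserialize_value cursor pattern, with hypothetical helper names, shown only to make that contract explicit:

    #include <cstddef>
    #include <cstring>

    // Writer side: memcpy the value and advance the cursor.
    template <typename T>
    void serializeOne(char** cursor, const T& value) {
      std::memcpy(*cursor, &value, sizeof(T));
      *cursor += sizeof(T);
    }

    // Reader side: must consume the same types in the same order,
    // shrinking the remaining-length bookkeeping as it goes.
    template <typename T>
    void deserializeOne(const void** cursor, size_t* remaining, T* value) {
      std::memcpy(value, *cursor, sizeof(T));
      *cursor = static_cast<const char*>(*cursor) + sizeof(T);
      *remaining -= sizeof(T);
    }

Any mismatch between this read order and the write order silently corrupts every field that follows, which is why getSerializationSize() enumerates the exact same fields.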
deserialize_value(&data, &length, &mLd); + deserialize_value(&data, &length, &nbLookupTables_); + for (int32_t i = 0; i < nbLookupTables_; ++i) { + int32_t tem; + deserialize_value(&data, &length, &tem); + mIdsVocabSize.push_back(tem); + } + char const* d = static_cast<char const*>(data); + mBeta.convertAndCopy(&d, mLd, nvinfer1::DataType::kFLOAT); + mGamma.convertAndCopy(&d, mLd, nvinfer1::DataType::kFLOAT); + for (int32_t i = 0; i < nbLookupTables_; ++i) { + nvinfer1::Weights pre_tem_weight; + pre_tem_weight.type = mType; + pre_tem_weight.count = mLd * size_t(mIdsVocabSize[i]); + const auto nbBytes = mLd * size_t(mIdsVocabSize[i]) * getElementSize(mType); + auto destBuf = new char[nbBytes]; + pre_tem_weight.values = destBuf; + std::copy_n(d, nbBytes, destBuf); + d += nbBytes; + mIdsEmb_.push_back(pre_tem_weight); + } +} + +// IPluginV2DynamicExt Methods +nvinfer1::IPluginV2DynamicExt* EmbLayerNormPlugin::clone() const noexcept { + TRANSFORMER_DEBUG_MSG("EmbLayerNormPlugin clone"); + auto p = new EmbLayerNormPlugin(mLayerName, mType, mBeta, mGamma, mIdsEmb_); + p->setPluginNamespace(mNamespace.c_str()); + return p; +} + +nvinfer1::DimsExprs EmbLayerNormPlugin::getOutputDimensions( + int32_t outputIndex, + nvinfer1::DimsExprs const* inputs, + int32_t nbInputs, + nvinfer1::IExprBuilder& exprBuilder) noexcept { + assert(outputIndex == 0); + nvinfer1::DimsExprs ret; + ret.nbDims = 3; + ret.d[0] = inputs[0].d[0]; + ret.d[1] = inputs[0].d[1]; + ret.d[2] = exprBuilder.constant(mLd); + return ret; +} + +bool EmbLayerNormPlugin::supportsFormatCombination( + int32_t pos, + nvinfer1::PluginTensorDesc const* inOut, + int32_t nbInputs, + int32_t nbOutputs) noexcept { + assert(nbOutputs == 1); + nvinfer1::PluginTensorDesc const& prev = inOut[0]; + nvinfer1::PluginTensorDesc const& desc = inOut[pos]; + if (desc.format != nvinfer1::TensorFormat::kLINEAR) { + return false; + } + if (pos == 0) { + return desc.type == nvinfer1::DataType::kINT32; + } + if (0 < pos && pos < nbInputs) { + assert(desc.dims.nbDims == prev.dims.nbDims); + for (int i = 0; i < prev.dims.nbDims; ++i) { + assert(desc.dims.d[i] == prev.dims.d[i]); + } + return desc.type == prev.type; + } + if (pos == nbInputs) { // output + return desc.type == mType && desc.dims.nbDims == 3 && + desc.dims.d[0] == prev.dims.d[0] && desc.dims.d[1] == prev.dims.d[1]; + } + return false; +} + +void EmbLayerNormPlugin::configurePlugin( + nvinfer1::DynamicPluginTensorDesc const* inputs, + int32_t nbInputs, + nvinfer1::DynamicPluginTensorDesc const* outputs, + int32_t nbOutputs) noexcept { + TRANSFORMER_DEBUG_MSG("EmbLayerNormPlugin configurePlugin"); + assert(static_cast<size_t>(outputs[0].desc.dims.d[2]) == + static_cast<size_t>(mLd)); + int32_t const B = inputs[0].desc.dims.d[0]; + if (B > 0) { + assert(outputs[0].desc.dims.d[0] == B); + } + assert(outputs[0].desc.type == mType); +} + +size_t EmbLayerNormPlugin::getWorkspaceSize( + nvinfer1::PluginTensorDesc const* inputs, + int32_t nbInputs, + nvinfer1::PluginTensorDesc const* outputs, + int32_t nbOutputs) const noexcept { + return 0; +} + +int32_t EmbLayerNormPlugin::enqueue( + nvinfer1::PluginTensorDesc const* inputDesc, + nvinfer1::PluginTensorDesc const* outputDesc, + void const* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream) noexcept { + int32_t batchSize = inputDesc[0].dims.d[0]; + int32_t const maxSeqlen = inputDesc[0].dims.d[1]; + if (maxSeqlen > 512) { + PADDLE_THROW(platform::errors::InvalidArgument( + "EmbLayerNormPlugin only supports maxSeqlen <= 512")); + } + const float* beta = mBetaDev.get(); + const
float* gamma = mGammaDev.get(); + if (mType == nvinfer1::DataType::kFLOAT) { + auto output = static_cast<float*>(outputs[0]); + if (nbLookupTables_ == 2) { + return embSkipLayerNorm_2<float>( + stream, + static_cast<int32_t>(mLd), + batchSize, + maxSeqlen, + static_cast<int32_t const*>(inputs[0]), + static_cast<int32_t const*>(inputs[1]), + nbLookupTables_, + beta, + gamma, + static_cast<float const*>(mIdsEmbPtrs[0]), + static_cast<float const*>(mIdsEmbPtrs[1]), + mIdsVocabSize[0], + mIdsVocabSize[1], + output); + } else if (nbLookupTables_ == 3) { + return embSkipLayerNorm_3<float>( + stream, + static_cast<int32_t>(mLd), + batchSize, + maxSeqlen, + static_cast<int32_t const*>(inputs[0]), + static_cast<int32_t const*>(inputs[1]), + static_cast<int32_t const*>(inputs[2]), + nbLookupTables_, + beta, + gamma, + static_cast<float const*>(mIdsEmbPtrs[0]), + static_cast<float const*>(mIdsEmbPtrs[1]), + static_cast<float const*>(mIdsEmbPtrs[2]), + mIdsVocabSize[0], + mIdsVocabSize[1], + mIdsVocabSize[2], + output); + } else if (nbLookupTables_ == 4) { + return embSkipLayerNorm_4<float>( + stream, + static_cast<int32_t>(mLd), + batchSize, + maxSeqlen, + static_cast<int32_t const*>(inputs[0]), + static_cast<int32_t const*>(inputs[1]), + static_cast<int32_t const*>(inputs[2]), + static_cast<int32_t const*>(inputs[3]), + nbLookupTables_, + beta, + gamma, + static_cast<float const*>(mIdsEmbPtrs[0]), + static_cast<float const*>(mIdsEmbPtrs[1]), + static_cast<float const*>(mIdsEmbPtrs[2]), + static_cast<float const*>(mIdsEmbPtrs[3]), + mIdsVocabSize[0], + mIdsVocabSize[1], + mIdsVocabSize[2], + mIdsVocabSize[3], + output); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "Only 2, 3, or 4 fused lookup_tables are supported ")); + } + } else if (mType == nvinfer1::DataType::kHALF) { + auto output = static_cast<half*>(outputs[0]); + if (nbLookupTables_ == 2) { + return embSkipLayerNorm_2<half>(stream, + static_cast<int32_t>(mLd), + batchSize, + maxSeqlen, + static_cast<int32_t const*>(inputs[0]), + static_cast<int32_t const*>(inputs[1]), + nbLookupTables_, + beta, + gamma, + static_cast<half const*>(mIdsEmbPtrs[0]), + static_cast<half const*>(mIdsEmbPtrs[1]), + mIdsVocabSize[0], + mIdsVocabSize[1], + output); + } else if (nbLookupTables_ == 3) { + return embSkipLayerNorm_3<half>(stream, + static_cast<int32_t>(mLd), + batchSize, + maxSeqlen, + static_cast<int32_t const*>(inputs[0]), + static_cast<int32_t const*>(inputs[1]), + static_cast<int32_t const*>(inputs[2]), + nbLookupTables_, + beta, + gamma, + static_cast<half const*>(mIdsEmbPtrs[0]), + static_cast<half const*>(mIdsEmbPtrs[1]), + static_cast<half const*>(mIdsEmbPtrs[2]), + mIdsVocabSize[0], + mIdsVocabSize[1], + mIdsVocabSize[2], + output); + } else if (nbLookupTables_ == 4) { + return embSkipLayerNorm_4<half>(stream, + static_cast<int32_t>(mLd), + batchSize, + maxSeqlen, + static_cast<int32_t const*>(inputs[0]), + static_cast<int32_t const*>(inputs[1]), + static_cast<int32_t const*>(inputs[2]), + static_cast<int32_t const*>(inputs[3]), + nbLookupTables_, + beta, + gamma, + static_cast<half const*>(mIdsEmbPtrs[0]), + static_cast<half const*>(mIdsEmbPtrs[1]), + static_cast<half const*>(mIdsEmbPtrs[2]), + static_cast<half const*>(mIdsEmbPtrs[3]), + mIdsVocabSize[0], + mIdsVocabSize[1], + mIdsVocabSize[2], + mIdsVocabSize[3], + output); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "Only 2, 3, or 4 fused lookup_tables are supported ")); + } + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "Unsupported data type, expected [kHALF, kFLOAT]")); + } + return STATUS_SUCCESS; +} + +// IPluginV2Ext Methods +nvinfer1::DataType EmbLayerNormPlugin::getOutputDataType( + int32_t index, + nvinfer1::DataType const* inputTypes, + int32_t nbInputs) const noexcept { + assert(index == 0); + assert(mType == nvinfer1::DataType::kHALF || + mType == nvinfer1::DataType::kFLOAT); + return mType; +} + +// IPluginV2 Methods +char const* EmbLayerNormPlugin::getPluginType() const noexcept { + return EMB_LAYER_NORM_NAME; +} + +char const* EmbLayerNormPlugin::getPluginVersion() const noexcept { + return EMB_LAYER_NORM_VERSION; +} + +int32_t EmbLayerNormPlugin::getNbOutputs() const
+
+size_t EmbLayerNormPlugin::getSerializationSize() const noexcept {
+  size_t const wordSize = getElementSize(mType);
+  return 2 * sizeof(float) * mLd                            // beta + gamma
+         + sizeof(mType)                                    //
+         + sizeof(mLd)                                      //
+         + mIdsVocabSize.size() * sizeof(mIdsVocabSize[0])  //
+         + wordSize * mLd *
+               std::accumulate(
+                   mIdsVocabSize.begin(), mIdsVocabSize.end(), 0)  // ids emb
+         + sizeof(nbLookupTables_);  // number of lookup_tables
+}
+
+void EmbLayerNormPlugin::serialize(void* buffer) const noexcept {
+  serialize_value(&buffer, mType);
+  serialize_value(&buffer, mLd);
+  serialize_value(&buffer, nbLookupTables_);
+  for (size_t i = 0; i < mIdsVocabSize.size(); ++i) {
+    serialize_value(&buffer, mIdsVocabSize[i]);
+  }
+  char* d = static_cast<char*>(buffer);
+  size_t const wordSize = getElementSize(mType);
+  serFromDev(&d, mBetaDev.get(), mLd);
+  serFromDev(&d, mGammaDev.get(), mLd);
+  for (size_t i = 0; i < mIdsEmbPtrs.size(); ++i) {
+    serFromDev(&d,
+               static_cast<char*>(mIdsEmbPtrs[i]),
+               mLd * mIdsVocabSize[i] * wordSize);
+  }
+}
+
+void EmbLayerNormPlugin::destroy() noexcept {
+  // This gets called when the network containing the plugin is destroyed.
+  mBetaDev.reset(nullptr);
+  mGammaDev.reset(nullptr);
+  for (size_t i = 0; i < mIdsEmbPtrs.size(); ++i) {
+    cudaFree(mIdsEmbPtrs[i]);
+  }
+  delete this;
+}
+
+void EmbLayerNormPlugin::setPluginNamespace(char const* libNamespace) noexcept {
+  mNamespace = libNamespace;
+}
+
+char const* EmbLayerNormPlugin::getPluginNamespace() const noexcept {
+  return mNamespace.c_str();
+}
+
+EmbLayerNormPluginCreator::EmbLayerNormPluginCreator() {}
+
+char const* EmbLayerNormPluginCreator::getPluginName() const noexcept {
+  return EMB_LAYER_NORM_NAME;
+}
+
+char const* EmbLayerNormPluginCreator::getPluginVersion() const noexcept {
+  return EMB_LAYER_NORM_VERSION;
+}
+
+nvinfer1::PluginFieldCollection const*
+EmbLayerNormPluginCreator::getFieldNames() noexcept {
+  return &mFC;
+}
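+
+// Field protocol: createPlugin() expects "bert_embeddings_layernorm_beta",
+// "bert_embeddings_layernorm_gamma", "output_fp16", and one
+// "bert_embeddings_word_embeddings_<k>" entry per lookup table. The (i - 3)
+// index below assumes the three non-embedding fields are listed first.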
+bool initialize_fields(nvinfer1::PluginFieldCollection const* fc,
+                       nvinfer1::Weights* beta,
+                       nvinfer1::Weights* gamma,
+                       std::vector<nvinfer1::Weights>* IdsEmb) {
+  bool output_fp16 = false;
+  for (int32_t i = 0; i < fc->nbFields; i++) {
+    std::string field_name(fc->fields[i].name);
+    if (field_name.compare("bert_embeddings_layernorm_beta") == 0) {
+      TRANSFORMER_DEBUG_MSG("Building bert_embeddings_layernorm_beta...");
+      beta->values = fc->fields[i].data;
+      beta->count = fc->fields[i].length;
+      beta->type = fieldTypeToDataType(fc->fields[i].type);
+    }
+
+    if (field_name.compare("bert_embeddings_layernorm_gamma") == 0) {
+      TRANSFORMER_DEBUG_MSG("Building bert_embeddings_layernorm_gamma...");
+      gamma->values = fc->fields[i].data;
+      gamma->count = fc->fields[i].length;
+      gamma->type = fieldTypeToDataType(fc->fields[i].type);
+    }
+
+    if (field_name.compare("output_fp16") == 0) {
+      TRANSFORMER_DEBUG_MSG("Building output_fp16...");
+      assert(fc->fields[i].type == nvinfer1::PluginFieldType::kINT32);
+      output_fp16 = static_cast<int32_t const*>(fc->fields[i].data)[0] != 0;
+    }
+    if (field_name.compare("bert_embeddings_word_embeddings_" +
+                           std::to_string(i - 3)) == 0) {
+      TRANSFORMER_DEBUG_MSG(
+          ("bert_embeddings_word_embeddings_" + std::to_string(i - 3)).c_str());
+      nvinfer1::Weights tem;
+      tem.values = fc->fields[i].data;
+      tem.count = fc->fields[i].length;
+      tem.type = fieldTypeToDataType(fc->fields[i].type);
+      IdsEmb->push_back(tem);
+    }
+  }
+  return output_fp16;
+}
+
+nvinfer1::IPluginV2* EmbLayerNormPluginCreator::createPlugin(
+    char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept {
+  TRANSFORMER_DEBUG_MSG("EmbLayerNormPlugin createPlugin");
+  nvinfer1::Weights beta;
+  nvinfer1::Weights gamma;
+  std::vector<nvinfer1::Weights> IdsEmb;
+  bool output_fp16 = initialize_fields(fc, &beta, &gamma, &IdsEmb);
+  TRANSFORMER_DEBUG_MSG("Building the Plugin...");
+  EmbLayerNormPlugin* p = new EmbLayerNormPlugin(
+      name,
+      output_fp16 ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT,
+      beta,
+      gamma,
+      IdsEmb);
+  return p;
+}
+
+nvinfer1::IPluginV2* EmbLayerNormPluginCreator::deserializePlugin(
+    char const* name, void const* serialData, size_t serialLength) noexcept {
+  return new EmbLayerNormPlugin(name, serialData, serialLength);
+}
+
+void EmbLayerNormPluginCreator::setPluginNamespace(
+    char const* libNamespace) noexcept {
+  mNamespace = libNamespace;
+}
+
+char const* EmbLayerNormPluginCreator::getPluginNamespace() const noexcept {
+  return mNamespace.c_str();
+}
+
+}  // namespace plugin
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
diff --git a/paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_plugin.h b/paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_plugin.h
new file mode 100644
index 0000000000000..a48287dc92cd9
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_plugin.h
@@ -0,0 +1,203 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+// SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION &
+// AFFILIATES. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <cuda.h>
+#include "NvInferPlugin.h"
+#include "NvInferRuntime.h"
+
+#include "paddle/fluid/inference/tensorrt/plugin/common/bertCommon.h"
+#include "paddle/fluid/inference/tensorrt/plugin/common/plugin.h"
+#include "paddle/fluid/inference/tensorrt/plugin/common/serialize.h"
+#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
+#include "paddle/fluid/platform/enforce.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+namespace plugin {
+
+template <typename T>
+int32_t embSkipLayerNorm_2(cudaStream_t, int32_t, int32_t, int32_t,
+                           int32_t const*, int32_t const*, int32_t,
+                           float const*, float const*, T const*, T const*,
+                           int32_t, int32_t, T*);
+
+template <typename T>
+int32_t embSkipLayerNorm_3(cudaStream_t, int32_t, int32_t, int32_t,
+                           int32_t const*, int32_t const*, int32_t const*,
+                           int32_t, float const*, float const*, T const*,
+                           T const*, T const*, int32_t, int32_t, int32_t, T*);
+
+template <typename T>
+int32_t embSkipLayerNorm_4(cudaStream_t, int32_t, int32_t, int32_t,
+                           int32_t const*, int32_t const*, int32_t const*,
+                           int32_t const*, int32_t, float const*, float const*,
+                           T const*, T const*, T const*, T const*, int32_t,
+                           int32_t, int32_t, int32_t, T*);
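+
+// The three entry points above are the fused lookup + eltwise-add + layernorm
+// kernels added by this patch (see many_emb_Layernorm_kernel.cu in the
+// diffstat); each variant handles a fixed number of embedding tables and is
+// instantiated for T = float and T = half, matching the enqueue() dispatch.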
+class EmbLayerNormPlugin : public nvinfer1::IPluginV2DynamicExt {
+ public:
+  EmbLayerNormPlugin(std::string const& name,
+                     nvinfer1::DataType const type,
+                     nvinfer1::Weights const& beta,
+                     nvinfer1::Weights const& gamma,
+                     const std::vector<nvinfer1::Weights>& ids_emb);
+
+  EmbLayerNormPlugin(std::string const& name, void const* data, size_t length);
+
+  EmbLayerNormPlugin() = delete;
+
+  nvinfer1::IPluginV2DynamicExt* clone() const noexcept override;
+
+  nvinfer1::DimsExprs getOutputDimensions(
+      int32_t outputIndex,
+      const nvinfer1::DimsExprs* inputs,
+      int32_t nbInputs,
+      nvinfer1::IExprBuilder& exprBuilder) noexcept override;
+
+  void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in,
+                       int32_t nbInputs,
+                       nvinfer1::DynamicPluginTensorDesc const* out,
+                       int32_t nbOutputs) noexcept override;
+
+  int32_t enqueue(nvinfer1::PluginTensorDesc const* inputDesc,
+                  nvinfer1::PluginTensorDesc const* outputDesc,
+                  void const* const* inputs,
+                  void* const* outputs,
+                  void* workspace,
+                  cudaStream_t stream) noexcept override;
+
+  int32_t initialize() noexcept override;
+
+  void terminate() noexcept override;
+
+  char const* getPluginVersion() const noexcept override;
+
+  bool supportsFormatCombination(int32_t pos,
+                                 nvinfer1::PluginTensorDesc const* inOut,
+                                 int32_t nbInputs,
+                                 int32_t nbOutputs) noexcept override;
+
+  size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs,
+                          int32_t nbInputs,
+                          nvinfer1::PluginTensorDesc const* outputs,
+                          int32_t nbOutputs) const noexcept override;
+
+  nvinfer1::DataType getOutputDataType(
+      int32_t index,
+      nvinfer1::DataType const* inputTypes,
+      int32_t nbInputs) const noexcept override;
+
+  char const* getPluginType() const noexcept override;
+
+  int32_t getNbOutputs() const noexcept override;
+
+  size_t getSerializationSize() const noexcept override;
+
+  void serialize(void* buffer) const noexcept override;
+
+  void destroy() noexcept override;
+
+  char const* getPluginNamespace() const noexcept override;
+
+  void setPluginNamespace(char const* pluginNamespace) noexcept override;
+
+ protected:
+  std::string const mLayerName;
+  std::string mNamespace;
+  cuda_unique_ptr<float> mGammaDev;
+  cuda_unique_ptr<float> mBetaDev;
+  std::vector<void*> mIdsEmbPtrs;
+  size_t mLd;  // leading dim = hidden size
+  std::vector<int32_t> mIdsVocabSize;
+  WeightsWithOwnership mBeta;
+  WeightsWithOwnership mGamma;
+  nvinfer1::DataType mType;
+  std::vector<nvinfer1::Weights> mIdsEmb_;
+  int32_t nbLookupTables_ = 0;
+};
+
+class EmbLayerNormPluginCreator : public nvinfer1::IPluginCreator {
+ public:
+  EmbLayerNormPluginCreator();
+
+  char const* getPluginName() const noexcept override;
+
+  const nvinfer1::PluginFieldCollection* getFieldNames() noexcept override;
+
+  void setPluginNamespace(char const* pluginNamespace) noexcept override;
+
+  char const* getPluginNamespace() const noexcept override;
+
+  nvinfer1::IPluginV2* createPlugin(
+      char const* name,
+      const nvinfer1::PluginFieldCollection* fc) noexcept override;
+
+  char const* getPluginVersion() const noexcept override;
+
+  nvinfer1::IPluginV2* deserializePlugin(char const* name,
+                                         void const* serialData,
+                                         size_t serialLength) noexcept override;
+
+ protected:
+  static nvinfer1::PluginFieldCollection mFC;
+  static std::vector<nvinfer1::PluginField> mPluginAttributes;
+  std::string mNamespace;
+};
+
+REGISTER_TRT_PLUGIN_V2(EmbLayerNormPluginCreator);
+}  // namespace plugin
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
diff --git a/paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.cu
index 6a8b39d1139d0..01abd3a61863f 100644
--- a/paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.cu
+++ b/paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.cu
@@ -39,7 +39,8 @@ constexpr size_t packedMaskSize256 = xmmasM256 * threadsPerCta256;
 constexpr size_t packedMaskSize384 = xmmasM384 * threadsPerCta384;
 char const* EMB_LAYER_NORM_VAR_SEQLEN_VERSION_HFACE{"1"};
 char const* EMB_LAYER_NORM_VAR_SEQLEN_VERSION_MTRON{"2"};
-char const* EMB_LAYER_NORM_VAR_SEQLEN_NAME{"ManyEmbLayerNormPluginDynamic"};
+char const* EMB_LAYER_NORM_VAR_SEQLEN_NAME{
+    "ManyEmbLayerNormVarlenPluginDynamic"};
 // Static class fields initialization
 nvinfer1::PluginFieldCollection EmbLayerNormVarSeqlenPluginBaseCreator::mFC{};
 std::vector<nvinfer1::PluginField>
@@ -167,7 +168,6 @@ nvinfer1::DimsExprs EmbLayerNormVarSeqlenPluginHFace::getOutputDimensions(
     assert(inputs[i].nbDims == inputs[1].nbDims);  // same shape
   }
   assert(inputs[0].nbDims == 1);  // pos_id: B+1
-  assert(outputIndex == 0 || outputIndex == 1);
   if (outputIndex == 0) {
     nvinfer1::DimsExprs ret;
     ret.nbDims = 4;
@@ -176,25 +176,32 @@ nvinfer1::DimsExprs EmbLayerNormVarSeqlenPluginHFace::getOutputDimensions(
     ret.d[2] = exprBuilder.constant(1);
     ret.d[3] = exprBuilder.constant(1);
     return ret;
+  } else if (outputIndex == 1) {
+    // This is a hack: we just report some mask size and rely on the plugins
+    // to play nicely together.
+    // At runtime, depending on the actual maxSeqlen, the size might be
+    // different.
+    int32_t maskSize_ = packedMaskSize384;
+    auto maskSize = exprBuilder.constant(maskSize_);
+    auto fp16maskSize =
+        exprBuilder.operation(nvinfer1::DimensionOperation::kPROD,
+                              *maskSize,
+                              *exprBuilder.constant(2));
+    auto Bplus1 = inputs[0].d[0];  // pos_id
+    auto one = exprBuilder.constant(1);
+    auto B = exprBuilder.operation(
+        nvinfer1::DimensionOperation::kSUB, *Bplus1, *one);
+    nvinfer1::DimsExprs ret;
+    ret.nbDims = 2;
+    ret.d[0] = B;
+    ret.d[1] = fp16maskSize;
+    return ret;
+  } else {
+    nvinfer1::DimsExprs ret;
+    ret.nbDims = 1;
+    ret.d[0] = inputs[nbInputs - 1].d[1];  // mask id: max seqlen
+    return ret;
+  }
-
-  // This is a hack: we just report some mask size and rely the plugins to play
-  // nicely together.
-  // At runtime, depending on the actual maxSeqlen, the size might be
-  // different.
-  int32_t maskSize_ = packedMaskSize384;
-  auto maskSize = exprBuilder.constant(maskSize_);
-  auto fp16maskSize = exprBuilder.operation(
-      nvinfer1::DimensionOperation::kPROD, *maskSize, *exprBuilder.constant(2));
-  auto Bplus1 = inputs[0].d[0];  // pos_id
-  auto one = exprBuilder.constant(1);
-  auto B =
-      exprBuilder.operation(nvinfer1::DimensionOperation::kSUB, *Bplus1, *one);
-  nvinfer1::DimsExprs ret;
-  ret.nbDims = 2;
-  ret.d[0] = B;
-  ret.d[1] = fp16maskSize;
-  return ret;
 }
@@ -209,14 +216,20 @@ nvinfer1::DimsExprs EmbLayerNormVarSeqlenPluginMTron::getOutputDimensions(
     assert(inputs[i].nbDims == inputs[1].nbDims);  // same shape
   }
   assert(inputs[0].nbDims == 1);  // pos_id: B+1
-  assert(outputIndex == 0 || outputIndex == 1);
-  nvinfer1::DimsExprs ret;
-  ret.nbDims = 4;
-  ret.d[0] = inputs[1].d[0];
-  ret.d[1] = exprBuilder.constant(mLd);
-  ret.d[2] = exprBuilder.constant(1);
-  ret.d[3] = exprBuilder.constant(1);
-  return ret;
+  if (outputIndex == 0 || outputIndex == 1) {
+    nvinfer1::DimsExprs ret;
+    ret.nbDims = 4;
+    ret.d[0] = inputs[1].d[0];  // sum of seq length
+    ret.d[1] = exprBuilder.constant(mLd);
+    ret.d[2] = exprBuilder.constant(1);
+    ret.d[3] = exprBuilder.constant(1);
+    return ret;
+  } else {
+    nvinfer1::DimsExprs ret;
+    ret.nbDims = 1;
+    ret.d[0] = inputs[nbInputs - 1].d[1];  // mask id: max seqlen
+    return ret;
+  }
 }
@@ -224,7 +237,7 @@ bool EmbLayerNormVarSeqlenPluginBase::supportsFormatCombination(
     nvinfer1::PluginTensorDesc const* inOut,
     int32_t nbInputs,
     int32_t nbOutputs) noexcept {
-  assert(nbOutputs == 2);
+  assert(nbOutputs == 3);
   nvinfer1::PluginTensorDesc const& desc = inOut[pos];
   if (desc.format != nvinfer1::TensorFormat::kLINEAR) {
     return false;
@@ -241,8 +254,8 @@ bool EmbLayerNormVarSeqlenPluginBase::supportsFormatCombination(
     return desc.type == prev.type && desc.dims.nbDims == 1 &&
            desc.dims.d[0] == prev.dims.d[0];
   }
-  if (pos == nbInputs - 1) {  // max seq length
-    return desc.dims.nbDims == 1;
+  if (pos == nbInputs - 1) {  // mask id
+    return desc.type == prev.type;
   }
   // embedded sequence
   if (pos == nbInputs) {
@@ -250,8 +263,15 @@ bool EmbLayerNormVarSeqlenPluginBase::supportsFormatCombination(
     return desc.type == mType && desc.dims.nbDims == 4 &&
            desc.dims.d[0] == inOut[1].dims.d[0] && desc.dims.d[2] == 1 &&
            desc.dims.d[3] == 1;
   }
-  // mask
-  return desc.type == nvinfer1::DataType::kHALF;
+  // mask(HFace) or pre_layernorm_bias(MTron)
+  if (pos == nbInputs + 1) {
+    return desc.type == prev.type;
+  }
+  // max seqlen
+  if (pos == nbInputs + 2) {
+    return desc.type == prev.type;
+  }
+  return false;
 }
@@ -259,8 +278,7 @@ void checkConfigurationInputs(nvinfer1::DynamicPluginTensorDesc const* inputs,
                               nvinfer1::DynamicPluginTensorDesc const* outputs,
                               int32_t nbOutputs) noexcept {
   // Validate input arguments
-  // assert(nbInputs == 4);
-  assert(nbOutputs == 2);
+  assert(nbOutputs == 3);
   assert(inputs[0].desc.dims.nbDims == 1);
   assert(inputs[0].desc.type == nvinfer1::DataType::kINT32);
   for (int i = 1; i < nbInputs - 1; ++i) {
@@ -671,7 +689,7 @@ char const* EmbLayerNormVarSeqlenPluginMTron::getPluginVersion()
 }
 
 int32_t EmbLayerNormVarSeqlenPluginBase::getNbOutputs() const noexcept {
-  return 2;
+  return 3;
 }
 
 int32_t EmbLayerNormVarSeqlenPluginHFace::initialize() noexcept {
diff --git a/paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.h b/paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.h
index 944f3abb9dba8..e75e054aeef2f 100644
--- a/paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.h
+++ b/paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.h
@@ -194,7 +194,6 @@ class EmbLayerNormVarSeqlenPluginBase : public nvinfer1::IPluginV2DynamicExt {
   cuda_unique_ptr<float> mGammaDev;
   cuda_unique_ptr<float> mBetaDev;
   std::vector<void*> mIdsEmbPtrs;
-  // std::vector mIdsEmbDev;
   size_t mLd;  // leading dim = hidden size
   std::vector<int32_t> mIdsVocabSize;
   WeightsWithOwnership mBeta;

From a232615e1bb5d6f66b4b14438f3e62fb04e1cf01 Mon Sep 17 00:00:00 2001
From: Wangzheee <634486483@qq.com>
Date: Thu, 1 Dec 2022 03:25:18 +0000
Subject: [PATCH 02/19] fix

---
 .../inference/tensorrt/convert/emb_eltwise_layernorm.cc | 7 -------
 1 file changed, 7 deletions(-)

diff --git a/paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc b/paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc
index 62d98591f62ec..fcd93c217672d 100644
--- a/paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc
+++ b/paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc
@@ -206,13 +206,6 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
       if (enable_int8) {
         output_fp16 = 1;
       }
-      PADDLE_ENFORCE_EQ(
-          output_fp16,
-          1,
-          platform::errors::InvalidArgument(
-              "Only Precision::KHalf(fp16) is supported when infering "
-              "ernie(bert) model with config.EnableVarseqlen(). "
-              "But Precision::KFloat32 is setted."));
       std::vector<nvinfer1::PluginField> fields;
       std::vector<std::string> temp_fields_keys;

From 02d4aa2b18c0aea90c00f301228bfb56210f9c4a Mon Sep 17 00:00:00 2001
From: Wangzheee <634486483@qq.com>
Date: Thu, 1 Dec 2022 14:50:54 +0000
Subject: [PATCH 03/19] fix

---
 ...trt_dynamic_shape_ernie_serialize_deserialize_test.h | 8 +++++---
 .../inference/tests/api/trt_dynamic_shape_ernie_test.cc | 8 +++++---
 2 files changed, 10 insertions(+), 6 deletions(-)

diff --git a/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_serialize_deserialize_test.h b/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_serialize_deserialize_test.h
index aa533454fdfef..698137e8452f1 100644
--- a/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_serialize_deserialize_test.h
+++ b/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_serialize_deserialize_test.h
@@ -43,18 +43,18 @@ static void run(const AnalysisConfig& config, std::vector<float>* out_data) {
   tmp_input.reserve(run_batch * run_seq_len);
   tmp_four_input.reserve(run_batch * run_seq_len);
 
-  int64_t i0[run_seq_len] = {
+  int32_t i0[run_seq_len] = {
       1,   3558, 4,  75,  491, 89,  340,   313, 93,    4,
       255, 10,   75, 321, 4095, 1902, 4,   134, 49,    75,
       311, 14,   44, 178, 543, 15,  12043, 2,   75,    201,
       340, 9,    14, 44,  486, 218, 1140,  279, 12043, 2};
-  int64_t i1[run_seq_len] = {
+  int32_t i1[run_seq_len] = {
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0};
-  int64_t i2[run_seq_len] = {0,  1,  2,  3,  4,  5,  6,  7,  8,  9,
+  int32_t i2[run_seq_len] = {0,  1,  2,  3,  4,  5,  6,  7,  8,  9,
                              10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
                              20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
                              30, 31, 32, 33, 34, 35, 36, 37, 38, 39};
@@ -139,6 +139,8 @@ static void trt_ernie(bool with_fp16, std::vector<float> result) {
 #if defined _WIN32
 #else
   config.EnableTensorRtEngine(1 << 30, 1, 5, precision, true, false);
+  paddle_infer::experimental::InternalUtils::SetTransformerMaskid(
+      &config, "read_file_0.tmp_4");
 #endif
 
   config.SetTRTDynamicShapeInfo(
diff --git a/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_test.cc b/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_test.cc
index aeefcf1059243..4b2555f356faa 100644
--- a/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_test.cc
+++ b/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_test.cc
@@ -33,18 +33,18 @@ void run(const AnalysisConfig& config, std::vector<float>* out_data, int bs) {
   const int run_seq_len = 128;
   size_t len = run_batch * run_seq_len;
 
-  int64_t i0_bs1[run_seq_len] = {
+  int32_t i0_bs1[run_seq_len] = {
       1,   3558, 4,  75,  491, 89,  340,   313, 93,    4,
       255, 10,   75, 321, 4095, 1902, 4,   134, 49,    75,
       311, 14,   44, 178, 543, 15,  12043, 2,   75,    201,
       340, 9,    14, 44,  486, 218, 1140,  279, 12043, 2};
-  int64_t i1_bs1[run_seq_len] = {
+  int32_t i1_bs1[run_seq_len] = {
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0};
-  int64_t i2_bs1[run_seq_len] = {0,  1,  2,  3,  4,  5,  6,  7,  8,  9,
+  int32_t i2_bs1[run_seq_len] = {0,  1,  2,  3,  4,  5,  6,  7,  8,  9,
                                  10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
                                  20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
                                  30, 31, 32, 33, 34, 35, 36, 37, 38, 39};
@@ -133,6 +133,8 @@ void trt_ernie(bool with_fp16,
   config.EnableTensorRtEngine(1 << 30, 1, 5, precision, false, false);
   config.SetTRTDynamicShapeInfo(
       min_input_shape, max_input_shape, opt_input_shape);
+  paddle_infer::experimental::InternalUtils::SetTransformerMaskid(
+      &config, "read_file_0.tmp_4");
 
   std::vector<float> out_data;
   run(config, &out_data, batch_size);

From 44c17a18160d46638f54d29a6496dd11d24effc0 Mon Sep 17 00:00:00 2001
From: Wangzheee <634486483@qq.com>
Date: Fri, 2 Dec 2022 06:28:25 +0000
Subject: [PATCH 04/19] fix

---
 .../framework/ir/trt_embedding_eltwise_layernorm_fuse_pass.cc | 4 ++--
 paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.cc   | 4 ++--
 paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc     | 4 ++--
 .../fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc | 1 -
 4 files changed, 6 insertions(+), 7 deletions(-)

diff --git a/paddle/fluid/framework/ir/trt_embedding_eltwise_layernorm_fuse_pass.cc b/paddle/fluid/framework/ir/trt_embedding_eltwise_layernorm_fuse_pass.cc
index f870796a4c164..23ebbddf5796f 100644
--- a/paddle/fluid/framework/ir/trt_embedding_eltwise_layernorm_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/trt_embedding_eltwise_layernorm_fuse_pass.cc
@@ -441,14 +441,14 @@ void TrtEmbeddingEltwiseLayerNormFusePass::ApplyImpl(Graph* graph) const {
   std::string mask_id = Get<std::string>("tensorrt_transformer_maskid");
 
   if ((use_varseqlen && pos_id != "" && mask_id != "") ||
-      (!use_varseqlen && pos_id == "" && mask_id == "")) {
+      (!use_varseqlen && pos_id == "")) {
     VLOG(3) << "start trt_embedding_eltwise_layernorm_fuse_pass";
   } else {
     PADDLE_THROW(
         platform::errors::Fatal("Use transformer'varseqlen need config: "
                                 "use_varseqlen, set pos_id, set "
                                 "mask_id. Or not use varseqlen, do not set "
-                                "pos_id, set mask_id. Please "
Please " "reconfig")); } graph->Set(kEmbEltwiseLayernormPass, new bool(true)); diff --git a/paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.cc b/paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.cc index 4ecc9919f5485..1d17cba445905 100644 --- a/paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.cc +++ b/paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.cc @@ -1637,14 +1637,14 @@ void TrtMultiHeadMatmulV3FusePass::ApplyImpl(Graph* graph) const { "preln_embedding_eltwise_layernorm_fuse_" "pass. please use no_varseqlen")); } - } else if (!use_varseqlen && pos_id == "" && mask_id == "") { + } else if (!use_varseqlen && pos_id == "") { VLOG(3) << "start no_varseqlen_trt_multihead_matmul_fuse_pass"; } else { PADDLE_THROW( platform::errors::Fatal("Use transformer'varseqlen need config: " "use_varseqlen, set pos_id, set " "mask_id. Or not use varseqlen, do not set " - "pos_id, set mask_id. Please " + "pos_id. Please " "reconfig")); } graph->Set(kMultiheadMatmulPass, new bool(true)); diff --git a/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc b/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc index d33adab8b3ea7..2e578a06e38e1 100644 --- a/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc +++ b/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc @@ -207,14 +207,14 @@ void TrtSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const { "trt_embedding_eltwise_layernorm_fuse_pass, " "trt_multihead_matmul_fuse_pass. please use no_varseqlen")); } - } else if (!use_varseqlen && pos_id == "" && mask_id == "") { + } else if (!use_varseqlen && pos_id == "") { VLOG(3) << "start no_varseqlen trt_skip_layernorm_fuse_pass"; } else { PADDLE_THROW( platform::errors::Fatal("Use transformer'varseqlen need config: " "use_varseqlen, set pos_id, set " "mask_id. Or not use varseqlen, do not set " - "pos_id, set mask_id. Please " + "pos_id. 
Please " "reconfig")); } } diff --git a/paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc b/paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc index fcd93c217672d..a70839ee9d401 100644 --- a/paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc +++ b/paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc @@ -87,7 +87,6 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter { input_embs.push_back(weight.get()); emb_sizes.push_back(weight.get().count); } - // hidden = emb_dims[1]; } bias_weight = GetWeight(op_desc.Input("Bias").front(), &bias_dims); scale_weight = GetWeight(op_desc.Input("Scale").front(), &scale_dims); From f6bbd887150a3c7c1b762f16a9eec858f2939fd7 Mon Sep 17 00:00:00 2001 From: Wangzheee <634486483@qq.com> Date: Fri, 2 Dec 2022 10:52:54 +0000 Subject: [PATCH 05/19] fix unitest --- .../trt_dynamic_shape_ernie_serialize_deserialize_test.h | 5 ----- .../tests/api/trt_dynamic_shape_transformer_prune_test.cc | 6 +++--- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_serialize_deserialize_test.h b/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_serialize_deserialize_test.h index 698137e8452f1..41de6cc0d3f89 100644 --- a/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_serialize_deserialize_test.h +++ b/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_serialize_deserialize_test.h @@ -135,14 +135,9 @@ static void trt_ernie(bool with_fp16, std::vector result) { if (with_fp16) { precision = AnalysisConfig::Precision::kHalf; } - -#if defined _WIN32 -#else config.EnableTensorRtEngine(1 << 30, 1, 5, precision, true, false); paddle_infer::experimental::InternalUtils::SetTransformerMaskid( &config, "read_file_0.tmp_4"); -#endif - config.SetTRTDynamicShapeInfo( min_input_shape, max_input_shape, opt_input_shape); AnalysisConfig* config_deser = new AnalysisConfig(config); diff --git a/paddle/fluid/inference/tests/api/trt_dynamic_shape_transformer_prune_test.cc b/paddle/fluid/inference/tests/api/trt_dynamic_shape_transformer_prune_test.cc index 937303b595e13..e96e77e8038fb 100644 --- a/paddle/fluid/inference/tests/api/trt_dynamic_shape_transformer_prune_test.cc +++ b/paddle/fluid/inference/tests/api/trt_dynamic_shape_transformer_prune_test.cc @@ -33,15 +33,15 @@ void run(const AnalysisConfig& config, std::vector* out_data) { tmp_input.reserve(run_batch * run_seq_len); tmp_four_input.reserve(run_batch * run_seq_len); - int64_t i0[run_seq_len] = { + int32_t i0[run_seq_len] = { 1, 3558, 4, 75, 491, 89, 340, 313, 93, 4, 255, 10, 75, 321, 4095, 1902, 4, 134, 49, 75, 311, 14, 44, 178, 543, 15, 12043, 2, 75, 201, 340, 9, 14, 44, 486, 218, 1140, 279, 12043, 2}; - int64_t i1[run_seq_len] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + int32_t i1[run_seq_len] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39}; - int64_t i2[run_seq_len] = { + int32_t i2[run_seq_len] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, From db05e1cf746644ae09174ae92af21f46ff45d4b2 Mon Sep 17 00:00:00 2001 From: Wangzheee <634486483@qq.com> Date: Sun, 4 Dec 2022 13:01:48 +0000 Subject: [PATCH 06/19] fix --- paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git 
diff --git a/paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.cc b/paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.cc
index 1d17cba445905..57cc9443850a1 100644
--- a/paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.cc
@@ -1168,14 +1168,14 @@ void TrtMultiHeadMatmulV2FusePass::ApplyImpl(Graph* graph) const {
           "preln_embedding_eltwise_layernorm_fuse_"
           "pass. please use no_varseqlen"));
     }
-  } else if (!use_varseqlen && pos_id == "" && mask_id == "") {
+  } else if (!use_varseqlen && pos_id == "") {
     VLOG(3) << "start no_varseqlen_trt_multihead_matmul_fuse_pass";
   } else {
     PADDLE_THROW(
         platform::errors::Fatal("Use transformer'varseqlen need config: "
                                 "use_varseqlen, set pos_id, set "
                                 "mask_id. Or not use varseqlen, do not set "
-                                "pos_id, set mask_id. Please "
+                                "pos_id. Please "
                                 "reconfig"));
   }
   graph->Set(kMultiheadMatmulPass, new bool(true));

From c9c9dc23316263bee034d4067fddc2e1b15552d3 Mon Sep 17 00:00:00 2001
From: Wangzheee <634486483@qq.com>
Date: Mon, 5 Dec 2022 02:59:58 +0000
Subject: [PATCH 07/19] fix:

---
 .../framework/ir/trt_multihead_matmul_fuse_pass.cc  |  4 ++--
 ...dynamic_shape_ernie_serialize_deserialize_test.h | 13 ++++++++-----
 .../tests/api/trt_dynamic_shape_ernie_test.cc       |  8 +++-----
 .../api/trt_dynamic_shape_transformer_prune_test.cc |  6 +++---
 4 files changed, 16 insertions(+), 15 deletions(-)

diff --git a/paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.cc b/paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.cc
index 57cc9443850a1..1d17cba445905 100644
--- a/paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.cc
@@ -1168,14 +1168,14 @@ void TrtMultiHeadMatmulV2FusePass::ApplyImpl(Graph* graph) const {
           "preln_embedding_eltwise_layernorm_fuse_"
           "pass. please use no_varseqlen"));
     }
-  } else if (!use_varseqlen && pos_id == "") {
+  } else if (!use_varseqlen && pos_id == "" && mask_id == "") {
     VLOG(3) << "start no_varseqlen_trt_multihead_matmul_fuse_pass";
   } else {
     PADDLE_THROW(
         platform::errors::Fatal("Use transformer'varseqlen need config: "
                                 "use_varseqlen, set pos_id, set "
                                 "mask_id. Or not use varseqlen, do not set "
-                                "pos_id. Please "
Please " "reconfig")); } graph->Set(kMultiheadMatmulPass, new bool(true)); diff --git a/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_serialize_deserialize_test.h b/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_serialize_deserialize_test.h index 41de6cc0d3f89..aa533454fdfef 100644 --- a/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_serialize_deserialize_test.h +++ b/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_serialize_deserialize_test.h @@ -43,18 +43,18 @@ static void run(const AnalysisConfig& config, std::vector* out_data) { tmp_input.reserve(run_batch * run_seq_len); tmp_four_input.reserve(run_batch * run_seq_len); - int32_t i0[run_seq_len] = { + int64_t i0[run_seq_len] = { 1, 3558, 4, 75, 491, 89, 340, 313, 93, 4, 255, 10, 75, 321, 4095, 1902, 4, 134, 49, 75, 311, 14, 44, 178, 543, 15, 12043, 2, 75, 201, 340, 9, 14, 44, 486, 218, 1140, 279, 12043, 2}; - int32_t i1[run_seq_len] = { + int64_t i1[run_seq_len] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; - int32_t i2[run_seq_len] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + int64_t i2[run_seq_len] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39}; @@ -135,9 +135,12 @@ static void trt_ernie(bool with_fp16, std::vector result) { if (with_fp16) { precision = AnalysisConfig::Precision::kHalf; } + +#if defined _WIN32 +#else config.EnableTensorRtEngine(1 << 30, 1, 5, precision, true, false); - paddle_infer::experimental::InternalUtils::SetTransformerMaskid( - &config, "read_file_0.tmp_4"); +#endif + config.SetTRTDynamicShapeInfo( min_input_shape, max_input_shape, opt_input_shape); AnalysisConfig* config_deser = new AnalysisConfig(config); diff --git a/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_test.cc b/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_test.cc index 4b2555f356faa..aeefcf1059243 100644 --- a/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_test.cc +++ b/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_test.cc @@ -33,18 +33,18 @@ void run(const AnalysisConfig& config, std::vector* out_data, int bs) { const int run_seq_len = 128; size_t len = run_batch * run_seq_len; - int32_t i0_bs1[run_seq_len] = { + int64_t i0_bs1[run_seq_len] = { 1, 3558, 4, 75, 491, 89, 340, 313, 93, 4, 255, 10, 75, 321, 4095, 1902, 4, 134, 49, 75, 311, 14, 44, 178, 543, 15, 12043, 2, 75, 201, 340, 9, 14, 44, 486, 218, 1140, 279, 12043, 2}; - int32_t i1_bs1[run_seq_len] = { + int64_t i1_bs1[run_seq_len] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; - int32_t i2_bs1[run_seq_len] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + int64_t i2_bs1[run_seq_len] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39}; @@ -133,8 +133,6 @@ 
@@ -133,8 +133,6 @@ void trt_ernie(bool with_fp16,
   config.EnableTensorRtEngine(1 << 30, 1, 5, precision, false, false);
   config.SetTRTDynamicShapeInfo(
       min_input_shape, max_input_shape, opt_input_shape);
-  paddle_infer::experimental::InternalUtils::SetTransformerMaskid(
-      &config, "read_file_0.tmp_4");
 
   std::vector<float> out_data;
   run(config, &out_data, batch_size);
diff --git a/paddle/fluid/inference/tests/api/trt_dynamic_shape_transformer_prune_test.cc b/paddle/fluid/inference/tests/api/trt_dynamic_shape_transformer_prune_test.cc
index e96e77e8038fb..937303b595e13 100644
--- a/paddle/fluid/inference/tests/api/trt_dynamic_shape_transformer_prune_test.cc
+++ b/paddle/fluid/inference/tests/api/trt_dynamic_shape_transformer_prune_test.cc
@@ -33,15 +33,15 @@ void run(const AnalysisConfig& config, std::vector<float>* out_data) {
   tmp_input.reserve(run_batch * run_seq_len);
   tmp_four_input.reserve(run_batch * run_seq_len);
 
-  int32_t i0[run_seq_len] = {
+  int64_t i0[run_seq_len] = {
      1,   3558, 4,  75,  491, 89,  340,   313, 93,    4,
      255, 10,   75, 321, 4095, 1902, 4,   134, 49,    75,
      311, 14,   44, 178, 543, 15,  12043, 2,   75,    201,
      340, 9,    14, 44,  486, 218, 1140,  279, 12043, 2};
-  int32_t i1[run_seq_len] = {0,  1,  2,  3,  4,  5,  6,  7,  8,  9,
+  int64_t i1[run_seq_len] = {0,  1,  2,  3,  4,  5,  6,  7,  8,  9,
                              10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
                              20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
                              30, 31, 32, 33, 34, 35, 36, 37, 38, 39};
-  int32_t i2[run_seq_len] = {
+  int64_t i2[run_seq_len] = {
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
      0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
      0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
      0, 0, 0, 0, 0, 0, 0, 0};

From c2eab7b9d6e3aeda407af8e7412b439a8c809aa6 Mon Sep 17 00:00:00 2001
From: Wangzheee <634486483@qq.com>
Date: Mon, 5 Dec 2022 03:01:28 +0000
Subject: [PATCH 08/19] fix

---
 paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.cc | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.cc b/paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.cc
index 1d17cba445905..57cc9443850a1 100644
--- a/paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/trt_multihead_matmul_fuse_pass.cc
@@ -1168,14 +1168,14 @@ void TrtMultiHeadMatmulV2FusePass::ApplyImpl(Graph* graph) const {
           "preln_embedding_eltwise_layernorm_fuse_"
           "pass. please use no_varseqlen"));
     }
-  } else if (!use_varseqlen && pos_id == "" && mask_id == "") {
+  } else if (!use_varseqlen && pos_id == "") {
     VLOG(3) << "start no_varseqlen_trt_multihead_matmul_fuse_pass";
   } else {
     PADDLE_THROW(
         platform::errors::Fatal("Use transformer'varseqlen need config: "
                                 "use_varseqlen, set pos_id, set "
                                 "mask_id. Or not use varseqlen, do not set "
-                                "pos_id, set mask_id. Please "
Please " "reconfig")); } graph->Set(kMultiheadMatmulPass, new bool(true)); From 7cc2d6db8b224c41c5bcd11abfa8294d7c2d63b6 Mon Sep 17 00:00:00 2001 From: Wangzheee <634486483@qq.com> Date: Mon, 5 Dec 2022 06:13:04 +0000 Subject: [PATCH 09/19] fix --- .../trt_dynamic_shape_ernie_serialize_deserialize_test.h | 8 ++++---- .../inference/tests/api/trt_dynamic_shape_ernie_test.cc | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_serialize_deserialize_test.h b/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_serialize_deserialize_test.h index aa533454fdfef..78511f963c028 100644 --- a/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_serialize_deserialize_test.h +++ b/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_serialize_deserialize_test.h @@ -38,23 +38,23 @@ static void run(const AnalysisConfig& config, std::vector* out_data) { int run_batch = 1; const int run_seq_len = 128; - std::vector tmp_input; + std::vector tmp_input; std::vector tmp_four_input; tmp_input.reserve(run_batch * run_seq_len); tmp_four_input.reserve(run_batch * run_seq_len); - int64_t i0[run_seq_len] = { + int32_t i0[run_seq_len] = { 1, 3558, 4, 75, 491, 89, 340, 313, 93, 4, 255, 10, 75, 321, 4095, 1902, 4, 134, 49, 75, 311, 14, 44, 178, 543, 15, 12043, 2, 75, 201, 340, 9, 14, 44, 486, 218, 1140, 279, 12043, 2}; - int64_t i1[run_seq_len] = { + int32_t i1[run_seq_len] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; - int64_t i2[run_seq_len] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + int32_t i2[run_seq_len] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39}; diff --git a/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_test.cc b/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_test.cc index aeefcf1059243..a32d2161d9519 100644 --- a/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_test.cc +++ b/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_test.cc @@ -33,18 +33,18 @@ void run(const AnalysisConfig& config, std::vector* out_data, int bs) { const int run_seq_len = 128; size_t len = run_batch * run_seq_len; - int64_t i0_bs1[run_seq_len] = { + int32_t i0_bs1[run_seq_len] = { 1, 3558, 4, 75, 491, 89, 340, 313, 93, 4, 255, 10, 75, 321, 4095, 1902, 4, 134, 49, 75, 311, 14, 44, 178, 543, 15, 12043, 2, 75, 201, 340, 9, 14, 44, 486, 218, 1140, 279, 12043, 2}; - int64_t i1_bs1[run_seq_len] = { + int32_t i1_bs1[run_seq_len] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; - int64_t i2_bs1[run_seq_len] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + int32_t i2_bs1[run_seq_len] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39}; @@ -52,7 +52,7 @@ void 
@@ -52,7 +52,7 @@ void run(const AnalysisConfig& config, std::vector<float>* out_data, int bs) {
      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
      1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0};
-  std::vector<int64_t> i0_data(len), i1_data(len), i2_data(len);
+  std::vector<int32_t> i0_data(len), i1_data(len), i2_data(len);
   std::vector<float> i3_data(len);
 
   for (size_t i = 0; i < len; i++) {

From 04eb453585e864ea2fbf168ec90de0ee4fd3d6fb Mon Sep 17 00:00:00 2001
From: Wangzheee <634486483@qq.com>
Date: Mon, 5 Dec 2022 08:35:09 +0000
Subject: [PATCH 10/19] fix

---
 .../api/trt_dynamic_shape_ernie_serialize_deserialize_test.h | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_serialize_deserialize_test.h b/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_serialize_deserialize_test.h
index 78511f963c028..c178edde7abf8 100644
--- a/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_serialize_deserialize_test.h
+++ b/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_serialize_deserialize_test.h
@@ -135,12 +135,7 @@ static void trt_ernie(bool with_fp16, std::vector<float> result) {
   if (with_fp16) {
     precision = AnalysisConfig::Precision::kHalf;
   }
-
-#if defined _WIN32
-#else
   config.EnableTensorRtEngine(1 << 30, 1, 5, precision, true, false);
-#endif
-
   config.SetTRTDynamicShapeInfo(
       min_input_shape, max_input_shape, opt_input_shape);
   AnalysisConfig* config_deser = new AnalysisConfig(config);

From 50d30d39391334f219a385fb5b5b2f966e28d53f Mon Sep 17 00:00:00 2001
From: Wangzheee <634486483@qq.com>
Date: Mon, 5 Dec 2022 15:12:12 +0000
Subject: [PATCH 11/19] fix

---
 paddle/fluid/inference/tests/api/CMakeLists.txt | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/paddle/fluid/inference/tests/api/CMakeLists.txt b/paddle/fluid/inference/tests/api/CMakeLists.txt
index 7a281ddfcdf6a..fe2e6807f8112 100644
--- a/paddle/fluid/inference/tests/api/CMakeLists.txt
+++ b/paddle/fluid/inference/tests/api/CMakeLists.txt
@@ -1264,8 +1264,12 @@ if(WITH_GPU AND TENSORRT_FOUND)
   set_tests_properties(trt_quant_int8_yolov3_r50_test PROPERTIES TIMEOUT 400)
   set_tests_properties(trt_resnet50_test PROPERTIES TIMEOUT 300)
   set_tests_properties(trt_cascade_rcnn_test PROPERTIES TIMEOUT 300)
-  set_tests_properties(test_trt_dynamic_shape_ernie_ser_deser PROPERTIES TIMEOUT
-                                                                         300)
+  if(WIN32)
+
+  else()
+    set_tests_properties(test_trt_dynamic_shape_ernie_ser_deser
+                         PROPERTIES TIMEOUT 300)
+  endif()
   set_tests_properties(test_trt_dynamic_shape_ernie_fp16_ser_deser
                        PROPERTIES TIMEOUT 300)
   set_tests_properties(test_trt_dynamic_shape_ernie PROPERTIES TIMEOUT 300)

From 1595f3946319bd2e157dd8ee920d6e35c418252a Mon Sep 17 00:00:00 2001
From: Wangzheee <634486483@qq.com>
Date: Tue, 6 Dec 2022 02:15:03 +0000
Subject: [PATCH 12/19] fix

---
 paddle/fluid/framework/ir/CMakeLists.txt           |  2 +-
 .../fluid/inference/api/analysis_predictor.cc      |  3 --
 .../inference/api/paddle_pass_builder.cc           | 31 ++++++++-----------
 .../inference/tensorrt/convert/CMakeLists.txt      |  2 +-
 .../inference/tensorrt/plugin/CMakeLists.txt       |  2 +-
 .../fluid/inference/tests/api/CMakeLists.txt       |  8 ++---
 6 files changed, 18 insertions(+), 30 deletions(-)

diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index 06ea7acb3315e..a19272efa2b47 100644
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -136,7 +136,7 @@ if(WITH_TENSORRT)
   pass_library(preln_layernorm_x_fuse_pass inference)
 endif()
 
-if(WITH_TENSORRT AND NOT WIN32)
+if(WITH_TENSORRT)
   pass_library(trt_embedding_eltwise_layernorm_fuse_pass inference)
   pass_library(preln_embedding_eltwise_layernorm_fuse_pass inference)
 endif()
diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index c1ca6d8e9608c..e670295682453 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -2305,11 +2305,8 @@ USE_TRT_CONVERTER(conv3d_transpose);
 USE_TRT_CONVERTER(mish);
 USE_TRT_CONVERTER(deformable_conv);
 USE_TRT_CONVERTER(pool3d)
-#ifdef _WIN32
-#else
 USE_TRT_CONVERTER(fused_preln_embedding_eltwise_layernorm)
 USE_TRT_CONVERTER(fused_embedding_eltwise_layernorm);
-#endif
 USE_TRT_CONVERTER(preln_skip_layernorm)
 USE_TRT_CONVERTER(preln_residual_bias)
 USE_TRT_CONVERTER(c_allreduce_sum)
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index c964ce7e4d0d2..c39226f6bb86f 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -95,26 +95,21 @@ const std::vector<std::string> kTRTSubgraphPasses({
       "identity_scale_op_clean_pass",  //
       "add_support_int8_pass",         //
       // "fc_fuse_pass",               //
-      "simplify_with_basic_ops_pass",  //
-
-#if defined _WIN32
-#else
+      "simplify_with_basic_ops_pass",                 //
       "trt_embedding_eltwise_layernorm_fuse_pass",    //
       "preln_embedding_eltwise_layernorm_fuse_pass",  //
-#endif
-
-      "delete_c_identity_op_pass",              //
-      "trt_multihead_matmul_fuse_pass_v2",      //
-      "trt_multihead_matmul_fuse_pass_v3",      //
-      "multihead_matmul_roformer_fuse_pass",    //
-      "constant_folding_pass",                  //
-      "vit_attention_fuse_pass",                //
-      "trt_skip_layernorm_fuse_pass",           //
-      "preln_skip_layernorm_fuse_pass",         //
-      "layernorm_shift_partition_fuse_pass",    //
-      "merge_layernorm_fuse_pass",              //
-      "preln_residual_bias_fuse_pass",          //
-      "preln_layernorm_x_fuse_pass",            //
+      "delete_c_identity_op_pass",                    //
+      "trt_multihead_matmul_fuse_pass_v2",            //
+      "trt_multihead_matmul_fuse_pass_v3",            //
+      "multihead_matmul_roformer_fuse_pass",          //
+      "constant_folding_pass",                        //
+      "vit_attention_fuse_pass",                      //
+      "trt_skip_layernorm_fuse_pass",                 //
+      "preln_skip_layernorm_fuse_pass",               //
+      "layernorm_shift_partition_fuse_pass",          //
+      "merge_layernorm_fuse_pass",                    //
+      "preln_residual_bias_fuse_pass",                //
+      "preln_layernorm_x_fuse_pass",                  //
       // "set_transformer_input_convert_pass",        //
       "conv_bn_fuse_pass",                            //
       "unsqueeze2_eltwise_fuse_pass",                 //
diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
index 070e7c2c0fd8e..b285265893717 100644
--- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
+++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
@@ -90,7 +90,7 @@ list(
     fused_lookup_tables_op.cc
     expand_v2_op.cc)
 
-if(${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 7 AND NOT WIN32)
+if(${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 7)
   list(APPEND CONVERT_FILES emb_eltwise_layernorm.cc
        preln_emb_eltwise_layernorm.cc)
 endif()
diff --git a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
index 40f9ef127f5cc..08c167d6d501d 100644
--- a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
+++ b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
@@ -40,7 +40,7 @@ list(
     many_emb_layernorm_plugin.cu
     many_emb_Layernorm_kernel.cu)
 
-if(${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 7 AND NOT WIN32)
+if(${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 7)
   list(APPEND TRT_FILES many_emb_layernorm_varseqlen_plugin.cu
        many_emb_Layernorm_varseqlen_kernelMTron.cu
        many_emb_Layernorm_varseqlen_kernelHFace.cu)
 endif()
diff --git a/paddle/fluid/inference/tests/api/CMakeLists.txt b/paddle/fluid/inference/tests/api/CMakeLists.txt
index fe2e6807f8112..7a281ddfcdf6a 100644
--- a/paddle/fluid/inference/tests/api/CMakeLists.txt
+++ b/paddle/fluid/inference/tests/api/CMakeLists.txt
@@ -1264,12 +1264,8 @@ if(WITH_GPU AND TENSORRT_FOUND)
   set_tests_properties(trt_quant_int8_yolov3_r50_test PROPERTIES TIMEOUT 400)
   set_tests_properties(trt_resnet50_test PROPERTIES TIMEOUT 300)
   set_tests_properties(trt_cascade_rcnn_test PROPERTIES TIMEOUT 300)
-  if(WIN32)
-
-  else()
-    set_tests_properties(test_trt_dynamic_shape_ernie_ser_deser
-                         PROPERTIES TIMEOUT 300)
-  endif()
+  set_tests_properties(test_trt_dynamic_shape_ernie_ser_deser PROPERTIES TIMEOUT
+                                                                         300)
   set_tests_properties(test_trt_dynamic_shape_ernie_fp16_ser_deser
                        PROPERTIES TIMEOUT 300)
   set_tests_properties(test_trt_dynamic_shape_ernie PROPERTIES TIMEOUT 300)

From 3ed6d727bcc8fdb248f484557ff1b89766cefc34 Mon Sep 17 00:00:00 2001
From: Wangzheee <634486483@qq.com>
Date: Tue, 6 Dec 2022 03:22:22 +0000
Subject: [PATCH 13/19] fix

---
 .../fluid/inference/tests/api/trt_dynamic_shape_ernie_test.cc | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_test.cc b/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_test.cc
index a32d2161d9519..31d349970f0d1 100644
--- a/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_test.cc
+++ b/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_test.cc
@@ -130,6 +130,8 @@ void trt_ernie(bool with_fp16,
   if (with_fp16) {
     precision = AnalysisConfig::Precision::kHalf;
   }
+  paddle_infer::experimental::InternalUtils::SetTransformerMaskid(
+      &config, "read_file_0.tmp_4");
   config.EnableTensorRtEngine(1 << 30, 1, 5, precision, false, false);
   config.SetTRTDynamicShapeInfo(
       min_input_shape, max_input_shape, opt_input_shape);

From 42b4ef8457d94641d03468a75892f0fb4a977513 Mon Sep 17 00:00:00 2001
From: Wangzheee <634486483@qq.com>
Date: Tue, 6 Dec 2022 04:11:46 +0000
Subject: [PATCH 14/19] fix

---
 paddle/fluid/inference/api/paddle_pass_builder.cc | 1 +
 1 file changed, 1 insertion(+)

diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index c39226f6bb86f..5426000fa7c13 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -110,6 +110,7 @@ const std::vector<std::string> kTRTSubgraphPasses({
       "merge_layernorm_fuse_pass",            //
       "preln_residual_bias_fuse_pass",        //
       "preln_layernorm_x_fuse_pass",          //
+      "reverse_roll_fuse_pass",               //
       // "set_transformer_input_convert_pass",        //
       "conv_bn_fuse_pass",                            //
       "unsqueeze2_eltwise_fuse_pass",                 //

From 0f8c60da973e3fcac49d6348c09ebe5adf888c7b Mon Sep 17 00:00:00 2001
From: Wangzheee <634486483@qq.com>
Date: Tue, 6 Dec 2022 06:39:43 +0000
Subject: [PATCH 15/19] fix

---
 .../api/trt_dynamic_shape_ernie_serialize_deserialize_test.h | 1 +
 1 file changed, 1 insertion(+)

diff --git a/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_serialize_deserialize_test.h b/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_serialize_deserialize_test.h
index c178edde7abf8..b50826953386c 100644
--- a/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_serialize_deserialize_test.h
+++ b/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_serialize_deserialize_test.h
@@ -135,6 +135,7 @@ static void trt_ernie(bool with_fp16, std::vector<float> result) {
   if (with_fp16) {
     precision = AnalysisConfig::Precision::kHalf;
   }
+
   config.EnableTensorRtEngine(1 << 30, 1, 5, precision, true, false);
   config.SetTRTDynamicShapeInfo(
       min_input_shape, max_input_shape, opt_input_shape);

From 9e726cad4756bb199b1d09540d1f63cf69906bdf Mon Sep 17 00:00:00 2001
From: Wangzheee <634486483@qq.com>
Date: Tue, 6 Dec 2022 08:03:04 +0000
Subject: [PATCH 16/19] fix

---
 .../inference/api/paddle_pass_builder.cc | 23 +++++++++----------
 1 file changed, 11 insertions(+), 12 deletions(-)

diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index 5426000fa7c13..06cdd5a91716d 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -111,18 +111,17 @@ const std::vector<std::string> kTRTSubgraphPasses({
       "preln_residual_bias_fuse_pass",  //
       "preln_layernorm_x_fuse_pass",    //
       "reverse_roll_fuse_pass",         //
-      // "set_transformer_input_convert_pass",        //
-      "conv_bn_fuse_pass",                            //
-      "unsqueeze2_eltwise_fuse_pass",                 //
-      "trt_squeeze2_matmul_fuse_pass",                //
-      "trt_flatten2_matmul_fuse_pass",                //
-      "trt_map_matmul_v2_to_mul_pass",                //
-      "trt_map_matmul_v2_to_matmul_pass",             //
-      "trt_map_matmul_to_mul_pass",                   //
-      "fc_fuse_pass",                                 //
-      "conv_elementwise_add_fuse_pass",               //
-      "remove_padding_recover_padding_pass",          //
-      "delete_remove_padding_recover_padding_pass",   //
+      "conv_bn_fuse_pass",                           //
+      "unsqueeze2_eltwise_fuse_pass",                //
+      "trt_squeeze2_matmul_fuse_pass",               //
+      "trt_flatten2_matmul_fuse_pass",               //
+      "trt_map_matmul_v2_to_mul_pass",               //
+      "trt_map_matmul_v2_to_matmul_pass",            //
+      "trt_map_matmul_to_mul_pass",                  //
+      "fc_fuse_pass",                                //
+      "conv_elementwise_add_fuse_pass",              //
+      "remove_padding_recover_padding_pass",         //
+      "delete_remove_padding_recover_padding_pass",  //
       // "yolo_box_fuse_pass",      //
       "dense_fc_to_sparse_pass",                //
       "dense_multihead_matmul_to_sparse_pass",  //

From b1bd464106f9d715f62fa7384ff03c2c6486b64e Mon Sep 17 00:00:00 2001
From: Wangzheee <634486483@qq.com>
Date: Tue, 6 Dec 2022 15:17:25 +0000
Subject: [PATCH 17/19] fix

---
 .../fluid/inference/tests/api/trt_dynamic_shape_ernie_test.cc | 2 --
 1 file changed, 2 deletions(-)

diff --git a/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_test.cc b/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_test.cc
index 31d349970f0d1..a32d2161d9519 100644
--- a/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_test.cc
+++ b/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_test.cc
@@ -130,8 +130,6 @@ void trt_ernie(bool with_fp16,
   if (with_fp16) {
     precision = AnalysisConfig::Precision::kHalf;
   }
-  paddle_infer::experimental::InternalUtils::SetTransformerMaskid(
-      &config, "read_file_0.tmp_4");
   config.EnableTensorRtEngine(1 << 30, 1, 5, precision, false, false);
   config.SetTRTDynamicShapeInfo(
       min_input_shape, max_input_shape, opt_input_shape);

From db05e1cf746644ae09174ae92af21f46ff45d4b2 Mon Sep 17 00:00:00 2001
From: Wangzheee <634486483@qq.com>
Date: Wed, 7 Dec 2022 06:26:37 +0000
Subject: [PATCH 18/19] fix

---
 .../api/trt_dynamic_shape_ernie_serialize_deserialize_test.cc | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_serialize_deserialize_test.cc b/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_serialize_deserialize_test.cc
index f269432d4da1e..56da226e2739a 100644
--- a/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_serialize_deserialize_test.cc
+++ b/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_serialize_deserialize_test.cc
@@ -28,11 +28,13 @@ limitations under the License. */
 
 namespace paddle {
 namespace inference {
-
+#if defined _WIN32
+#else
 TEST(AnalysisPredictor, no_fp16) {
   std::vector<float> result = {0.597841, 0.219972, 0.182187};
   trt_ernie(false, result);
 }
+#endif
 
 }  // namespace inference
 }  // namespace paddle

From c3fe15687d3499ac19beff84e24890a71da6cbe7 Mon Sep 17 00:00:00 2001
From: Wangzheee <634486483@qq.com>
Date: Thu, 8 Dec 2022 01:28:20 +0000
Subject: [PATCH 19/19] fix

---
 paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt         | 4 ++--
 ...lHFace.cu => many_emb_Layernorm_varseqlen_kernel_hface.cu} | 0
 ...lMTron.cu => many_emb_Layernorm_varseqlen_kernel_mtron.cu} | 0
 3 files changed, 2 insertions(+), 2 deletions(-)
 rename paddle/fluid/inference/tensorrt/plugin/{many_emb_Layernorm_varseqlen_kernelHFace.cu => many_emb_Layernorm_varseqlen_kernel_hface.cu} (100%)
 rename paddle/fluid/inference/tensorrt/plugin/{many_emb_Layernorm_varseqlen_kernelMTron.cu => many_emb_Layernorm_varseqlen_kernel_mtron.cu} (100%)

diff --git a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
index 08c167d6d501d..74059453cddf2 100644
--- a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
+++ b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
@@ -42,8 +42,8 @@ list(
 
 if(${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 7)
   list(APPEND TRT_FILES many_emb_layernorm_varseqlen_plugin.cu
-       many_emb_Layernorm_varseqlen_kernelMTron.cu
-       many_emb_Layernorm_varseqlen_kernelHFace.cu)
+       many_emb_Layernorm_varseqlen_kernel_mtron.cu
+       many_emb_Layernorm_varseqlen_kernel_hface.cu)
 endif()
diff --git a/paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_varseqlen_kernelHFace.cu b/paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_varseqlen_kernel_hface.cu
similarity index 100%
rename from paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_varseqlen_kernelHFace.cu
rename to paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_varseqlen_kernel_hface.cu
diff --git a/paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_varseqlen_kernelMTron.cu b/paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_varseqlen_kernel_mtron.cu
similarity index 100%
rename from paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_varseqlen_kernelMTron.cu
rename to paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_varseqlen_kernel_mtron.cu