From 3df38f5cdd0866c1e78f1c2674d3d6cf3166d35f Mon Sep 17 00:00:00 2001 From: GaoWei8 <53294385+GaoWei8@users.noreply.github.com> Date: Fri, 10 Jan 2020 10:32:22 +0800 Subject: [PATCH] [cherry-pick] Add FC padding, ernie test unit and layernorm parallel (#22198) * Optimize the kernel implementation of layernorm with openmp (#20895) * Add ernie c++ inference test (#21015) * Add ernie unit test test=develop * Add ernie unit test test=develop * Add ernie unit test test=develop * remove ngraph * optimize gpu test test=develop * optimize codes test=develop * fix cmake fails on inference_download_and_uncompress (#21185) * solve cmake fails on inference_download_and_uncompress test=develop * solve cmake fails on inference_download_and_uncompress test=develop * Add fc padding to improve mkl GEMM's performance when N and K are multiple of 128. (#20972) * Add fc padding to solve mkl performance test=develop * fix gpu pass and error information test=develop * fix fc_fuse_pass_test test=develop * fix error information test=develop * fix error information test=develop * fix name and add fc op padding test test=develop * fix attributes test=develop * optimize fc padding test=develop * fix test test=develop * Polish the codes of fc when needs padding (#21378) test=develop * Add ernie large c++ inference test (#21365) * add ernie-large test test=develop * add ernie large c++ inference test test=develop * Modify padding strategy: remove weight copy in fc padding (#21650) test=develop * optimize fc jit (#21878) test=develop Co-authored-by: Yihua Xu --- paddle/fluid/framework/ir/fc_fuse_pass.cc | 29 ++ .../fluid/framework/ir/fc_fuse_pass_tester.cc | 20 ++ .../fluid/inference/tests/api/CMakeLists.txt | 37 +-- .../tests/api/analyzer_bert_tester.cc | 2 - .../tests/api/analyzer_ernie_tester.cc | 251 ++++++++++++++++++ .../fluid/inference/tests/api/tester_helper.h | 1 + paddle/fluid/operators/fc_op.cc | 34 ++- paddle/fluid/operators/fc_op.h | 26 +- .../jit/more/intrinsic/layer_norm.cc | 218 +++++++-------- paddle/fluid/operators/math/fc.cc | 59 ++-- paddle/fluid/operators/math/fc.cu | 7 +- paddle/fluid/operators/math/fc.h | 3 +- paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc | 7 +- .../fluid/tests/unittests/test_fc_op.py | 7 + 14 files changed, 545 insertions(+), 156 deletions(-) create mode 100644 paddle/fluid/inference/tests/api/analyzer_ernie_tester.cc diff --git a/paddle/fluid/framework/ir/fc_fuse_pass.cc b/paddle/fluid/framework/ir/fc_fuse_pass.cc index b53e6a250ced5..fd6930162dbaf 100644 --- a/paddle/fluid/framework/ir/fc_fuse_pass.cc +++ b/paddle/fluid/framework/ir/fc_fuse_pass.cc @@ -89,6 +89,35 @@ int FCFusePass::ApplyFCPattern(Graph* graph, bool with_relu) const { std::string activation_type = with_relu ? "relu" : ""; desc.SetAttr("activation_type", activation_type); + // This is to add padding for dimension 128 on concern of MKL performance + auto* scope = param_scope(); + auto* weight = scope->FindVar(w->Name())->GetMutable(); + auto place = weight->place(); + bool use_gpu = Get("use_gpu"); + auto* weight_data = weight->data(); + auto weight_dims = weight->dims(); + int weight_num = product(weight_dims); + int w_h = weight_dims[0]; + int w_w = weight_dims[1]; + if (!use_gpu) { + if (w_h % 128 == 0 && w_w % 128 == 0) { + auto* weight_data_tmp = new float[weight_num]; + for (int i = 0; i < w_h; i++) { + memcpy(weight_data_tmp + i * w_w, weight_data + i * w_w, + w_w * sizeof(float)); + } + weight->Resize(DDim{weight_dims[0] + 4, weight_dims[1] + 4}); + auto* weight_data_new = + weight->mutable_data(platform::CPUPlace()); + for (int i = 0; i < w_h; i++) { + memcpy(weight_data_new + i * (w_w + 4), weight_data_tmp + i * w_w, + w_w * sizeof(float)); + } + delete[] weight_data_tmp; + desc.SetAttr("padding_weights", true); + } + } + // For anakin subgraph int8 // When in anakin subgraph int8 mode, the pattern like "fake_quant + mul + // fake_dequant" can be detected by the quant_dequant_fuse_pass. This pass diff --git a/paddle/fluid/framework/ir/fc_fuse_pass_tester.cc b/paddle/fluid/framework/ir/fc_fuse_pass_tester.cc index 320d28f131f03..dfae572d4634e 100644 --- a/paddle/fluid/framework/ir/fc_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/fc_fuse_pass_tester.cc @@ -21,6 +21,24 @@ namespace paddle { namespace framework { namespace ir { +void AddVarToScope(Scope* param_scope, const std::string& name, + const DDim& dims) { + auto* tensor = param_scope->Var(name)->GetMutable(); + tensor->Resize(dims); + tensor->mutable_data(platform::CPUPlace()); +} + +Scope* CreateParamScope() { + auto param_scope = new Scope(); + AddVarToScope(param_scope, "conv2d_filters_0", {}); + AddVarToScope(param_scope, "conv2d_bias_0", {}); + AddVarToScope(param_scope, "weights_0", {}); + AddVarToScope(param_scope, "weights_1", {}); + AddVarToScope(param_scope, "bias_1", {}); + AddVarToScope(param_scope, "bias_2", {}); + return param_scope; +} + TEST(FCFusePass, basic) { // inputs operator output // -------------------------------------------------------- @@ -50,6 +68,8 @@ TEST(FCFusePass, basic) { std::unique_ptr graph(new ir::Graph(layers.main_program())); auto pass = PassRegistry::Instance().Get("fc_fuse_pass"); + pass->Set("use_gpu", new bool(true)); + graph->Set("__param_scope__", CreateParamScope()); int num_nodes_before = graph->Nodes().size(); int num_mul_nodes_before = GetNumOpNodes(graph, "mul"); VLOG(3) << DebugString(graph); diff --git a/paddle/fluid/inference/tests/api/CMakeLists.txt b/paddle/fluid/inference/tests/api/CMakeLists.txt index f8a61b46de4c2..3342c6dfef9d9 100644 --- a/paddle/fluid/inference/tests/api/CMakeLists.txt +++ b/paddle/fluid/inference/tests/api/CMakeLists.txt @@ -27,10 +27,14 @@ function(download_model_and_data install_dir model_name data_name) download_data(${install_dir} ${data_name}) endfunction() +function(download_result install_dir result_name) + download_data(${install_dir} ${result_name}) +endfunction() + function(inference_analysis_api_test target install_dir filename) inference_analysis_test(${target} SRCS ${filename} EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} benchmark - ARGS --infer_model=${install_dir}/model --infer_data=${install_dir}/data.txt) + ARGS --infer_model=${install_dir}/model --infer_data=${install_dir}/data.txt --refer_result=${install_dir}/result.txt) endfunction() function(inference_analysis_api_test_build TARGET_NAME filename) @@ -72,13 +76,6 @@ function(inference_analysis_api_test_with_fake_data_run TARGET_NAME test_binary --disable_mkldnn_fc=${disable_fc}) endfunction() -function(inference_analysis_api_test_with_refer_result target install_dir filename) - inference_analysis_test(${target} SRCS ${filename} - EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} - ARGS --infer_model=${install_dir}/model --infer_data=${install_dir}/data.txt - --refer_result=${install_dir}/result.txt) -endfunction() - function(inference_analysis_api_qat_test_run TARGET_NAME test_binary fp32_model_dir int8_model_dir data_path) inference_analysis_test_run(${TARGET_NAME} COMMAND ${test_binary} @@ -147,6 +144,20 @@ set(PYRAMID_DNN_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/pyramid_dnn") download_model_and_data(${PYRAMID_DNN_INSTALL_DIR} "PyramidDNN_model.tar.gz" "PyramidDNN_data.txt.tar.gz") inference_analysis_api_test(test_analyzer_pyramid_dnn ${PYRAMID_DNN_INSTALL_DIR} analyzer_pyramid_dnn_tester.cc) +#Ernie +set(ERNIE_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/Ernie") +download_model_and_data(${ERNIE_INSTALL_DIR} "Ernie_model.tar.gz" "Ernie_data.txt.tar.gz" "Ernie_result.txt.tar.gz") +download_result(${ERNIE_INSTALL_DIR} "Ernie_result.txt.tar.gz") +inference_analysis_api_test(test_analyzer_ernie ${ERNIE_INSTALL_DIR} analyzer_ernie_tester.cc) + +#Ernie large +set(ERNIE_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/Ernie_Large") +download_model_and_data(${ERNIE_INSTALL_DIR} "Ernie_large_model.tar.gz" "Ernie_large_data.txt.tar.gz" "Ernie_large_result.txt.tar.gz") +download_result(${ERNIE_INSTALL_DIR} "Ernie_large_result.txt.tar.gz") +inference_analysis_test(test_analyzer_ernie_large SRCS analyzer_ernie_tester.cc + EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} benchmark + ARGS --infer_model=${ERNIE_INSTALL_DIR}/model --infer_data=${ERNIE_INSTALL_DIR}/data.txt --refer_result=${ERNIE_INSTALL_DIR}/result.txt --ernie_large=true) + # text_classification set(TEXT_CLASSIFICATION_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/text_classification") download_model_and_data(${TEXT_CLASSIFICATION_INSTALL_DIR} "text-classification-Senta.tar.gz" "text_classification_data.txt.tar.gz") @@ -170,14 +181,14 @@ set(OCR_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/ocr") if (NOT EXISTS ${OCR_INSTALL_DIR}) inference_download_and_uncompress(${OCR_INSTALL_DIR} "http://paddlemodels.bj.bcebos.com/" "inference-vis-demos%2Focr.tar.gz") endif() -inference_analysis_api_test_with_refer_result(test_analyzer_ocr ${OCR_INSTALL_DIR} analyzer_vis_tester.cc) +inference_analysis_api_test(test_analyzer_ocr ${OCR_INSTALL_DIR} analyzer_vis_tester.cc) # mobilenet with transpose op set(MOBILENET_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/mobilenet") if (NOT EXISTS ${MOBILENET_INSTALL_DIR}) inference_download_and_uncompress(${MOBILENET_INSTALL_DIR} "http://paddlemodels.bj.bcebos.com/" "inference-vis-demos%2Fmobilenet.tar.gz") endif() -inference_analysis_api_test_with_refer_result(test_analyzer_mobilenet_transpose ${MOBILENET_INSTALL_DIR} analyzer_vis_tester.cc) +inference_analysis_api_test(test_analyzer_mobilenet_transpose ${MOBILENET_INSTALL_DIR} analyzer_vis_tester.cc) ### Image classification tests with fake data set(IMG_CLASS_TEST_APP "test_analyzer_image_classification") @@ -334,13 +345,9 @@ inference_analysis_test(test_analyzer_capi SRCS analyzer_capi_tester.cc EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} paddle_fluid_c ARGS --infer_model=${RESNET50_MODEL_DIR}/model) -set(CAPI_MODEL_INSTALL_PD_DIR "${INFERENCE_DEMO_INSTALL_DIR}/capi_mobilenet") -if (NOT EXISTS ${CAPI_MODEL_INSTALL_PD_DIR}) - inference_download_and_uncompress(${CAPI_MODEL_INSTALL_PD_DIR} "http://paddlemodels.bj.bcebos.com/" "inference-vis-demos%2Fmobilenet.tar.gz") -endif() inference_analysis_test(test_analyzer_capi_pd_tensor SRCS analyzer_capi_pd_tensor_tester.cc EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} paddle_fluid_c - ARGS --infer_model=${CAPI_MODEL_INSTALL_PD_DIR}/model) + ARGS --infer_model=${MOBILENET_INSTALL_DIR}/model) if(WITH_MKLDNN) inference_analysis_test(test_analyzer_capi_int SRCS analyzer_capi_int_tester.cc diff --git a/paddle/fluid/inference/tests/api/analyzer_bert_tester.cc b/paddle/fluid/inference/tests/api/analyzer_bert_tester.cc index f679e1221821a..5035f9b358718 100644 --- a/paddle/fluid/inference/tests/api/analyzer_bert_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_bert_tester.cc @@ -153,7 +153,6 @@ void profile(bool use_mkldnn = false, bool use_ngraph = false) { if (use_mkldnn) { config.EnableMKLDNN(); - config.pass_builder()->AppendPass("fc_mkldnn_pass"); } if (use_ngraph) { @@ -193,7 +192,6 @@ void compare(bool use_mkldnn = false, bool use_ngraph = false) { SetConfig(&cfg); if (use_mkldnn) { cfg.EnableMKLDNN(); - cfg.pass_builder()->AppendPass("fc_mkldnn_pass"); } if (use_ngraph) { diff --git a/paddle/fluid/inference/tests/api/analyzer_ernie_tester.cc b/paddle/fluid/inference/tests/api/analyzer_ernie_tester.cc new file mode 100644 index 0000000000000..199eee02d75c1 --- /dev/null +++ b/paddle/fluid/inference/tests/api/analyzer_ernie_tester.cc @@ -0,0 +1,251 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/inference/tests/api/tester_helper.h" + +namespace paddle { +namespace inference { + +using paddle::PaddleTensor; + +template +void GetValueFromStream(std::stringstream *ss, T *t) { + (*ss) >> (*t); +} + +template <> +void GetValueFromStream(std::stringstream *ss, std::string *t) { + *t = ss->str(); +} + +// Split string to vector +template +void Split(const std::string &line, char sep, std::vector *v) { + std::stringstream ss; + T t; + for (auto c : line) { + if (c != sep) { + ss << c; + } else { + GetValueFromStream(&ss, &t); + v->push_back(std::move(t)); + ss.str({}); + ss.clear(); + } + } + + if (!ss.str().empty()) { + GetValueFromStream(&ss, &t); + v->push_back(std::move(t)); + ss.str({}); + ss.clear(); + } +} + +// Parse tensor from string +template +bool ParseTensor(const std::string &field, paddle::PaddleTensor *tensor) { + std::vector data; + Split(field, ':', &data); + if (data.size() < 2) return false; + + std::string shape_str = data[0]; + + std::vector shape; + Split(shape_str, ' ', &shape); + + std::string mat_str = data[1]; + + std::vector mat; + Split(mat_str, ' ', &mat); + + tensor->shape = shape; + auto size = + std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()) * + sizeof(T); + tensor->data.Resize(size); + std::copy(mat.begin(), mat.end(), static_cast(tensor->data.data())); + tensor->dtype = GetPaddleDType(); + + return true; +} + +// Parse input tensors from string +bool ParseLine(const std::string &line, + std::vector *tensors) { + std::vector fields; + Split(line, ';', &fields); + + tensors->clear(); + tensors->reserve(4); + + int i = 0; + auto input_name = FLAGS_ernie_large ? "eval_placeholder_" : "placeholder_"; + for (; i < 3; i++) { + paddle::PaddleTensor temp; + ParseTensor(fields[i], &temp); + temp.name = input_name + std::to_string(i); + tensors->push_back(temp); + } + + // input_mask + paddle::PaddleTensor input_mask; + ParseTensor(fields[i], &input_mask); + input_mask.name = input_name + std::to_string(i); + tensors->push_back(input_mask); + + return true; +} + +bool LoadInputData(std::vector> *inputs) { + if (FLAGS_infer_data.empty()) { + LOG(ERROR) << "please set input data path"; + return false; + } + + std::ifstream fin(FLAGS_infer_data); + std::string line; + int sample = 0; + + // The unit-test dataset only have 10 samples, each sample have 5 feeds. + while (std::getline(fin, line)) { + std::vector feed_data; + ParseLine(line, &feed_data); + inputs->push_back(std::move(feed_data)); + sample++; + if (!FLAGS_test_all_data && sample == FLAGS_batch_size) break; + } + LOG(INFO) << "number of samples: " << sample; + return true; +} + +void SetConfig(AnalysisConfig *cfg, bool use_mkldnn = false, + bool use_gpu = false) { + cfg->SetModel(FLAGS_infer_model); + if (use_mkldnn) { + cfg->EnableMKLDNN(); + } + if (use_gpu) { + cfg->EnableUseGpu(100, 0); + } else { + cfg->DisableGpu(); + } + cfg->SwitchSpecifyInputNames(); + cfg->SwitchIrOptim(); + cfg->SetCpuMathLibraryNumThreads(FLAGS_paddle_num_threads); +} + +void profile(bool use_mkldnn = false, bool use_gpu = false) { + AnalysisConfig config; + SetConfig(&config, use_mkldnn, use_gpu); + + std::vector> outputs; + std::vector> inputs; + LoadInputData(&inputs); + TestPrediction(reinterpret_cast(&config), + inputs, &outputs, FLAGS_num_threads); +} + +TEST(Analyzer_ernie, profile) { profile(); } +#ifdef PADDLE_WITH_MKLDNN +TEST(Analyzer_ernie, profile_mkldnn) { profile(true, false); } +#endif + +// Check the model by gpu +#ifdef PADDLE_WITH_CUDA +TEST(Analyzer_ernie, profile_gpu) { profile(false, true); } +#endif + +// Check the fuse status +TEST(Analyzer_Ernie, fuse_statis) { + AnalysisConfig cfg; + SetConfig(&cfg); + + int num_ops; + auto predictor = CreatePaddlePredictor(cfg); + auto fuse_statis = GetFuseStatis( + static_cast(predictor.get()), &num_ops); + ASSERT_TRUE(fuse_statis.count("fc_fuse")); + LOG(INFO) << "num_ops: " << num_ops; + if (FLAGS_ernie_large) { + ASSERT_EQ(fuse_statis.at("fc_fuse"), 146); + EXPECT_EQ(num_ops, 859); + } else { + ASSERT_EQ(fuse_statis.at("fc_fuse"), 74); + EXPECT_EQ(num_ops, 295); + } +} + +// Compare result of NativeConfig and AnalysisConfig +void compare(bool use_mkldnn = false) { + AnalysisConfig cfg; + SetConfig(&cfg, use_mkldnn, false); + + std::vector> inputs; + LoadInputData(&inputs); + CompareNativeAndAnalysis( + reinterpret_cast(&cfg), inputs); +} + +TEST(Analyzer_ernie, compare) { compare(); } +#ifdef PADDLE_WITH_MKLDNN +TEST(Analyzer_ernie, compare_mkldnn) { compare(true /* use_mkldnn */); } +#endif + +// Compare Deterministic result +TEST(Analyzer_Ernie, compare_determine) { + AnalysisConfig cfg; + SetConfig(&cfg); + + std::vector> input_slots_all; + LoadInputData(&input_slots_all); + CompareDeterministic(reinterpret_cast(&cfg), + input_slots_all); +} + +// Compare results +TEST(Analyzer_Ernie, compare_results) { + AnalysisConfig cfg; + SetConfig(&cfg); + + std::vector> input_slots_all; + LoadInputData(&input_slots_all); + + std::ifstream fin(FLAGS_refer_result); + std::string line; + std::vector ref; + + while (std::getline(fin, line)) { + Split(line, ' ', &ref); + } + + auto predictor = CreateTestPredictor( + reinterpret_cast(&cfg), + FLAGS_use_analysis); + + std::vector outputs; + for (size_t i = 0; i < input_slots_all.size(); i++) { + outputs.clear(); + predictor->Run(input_slots_all[i], &outputs); + auto outputs_size = outputs.front().data.length() / (sizeof(float)); + for (size_t j = 0; j < outputs_size; ++j) { + EXPECT_NEAR(ref[i * outputs_size + j], + static_cast(outputs[0].data.data())[j], + FLAGS_accuracy); + } + } +} + +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tests/api/tester_helper.h b/paddle/fluid/inference/tests/api/tester_helper.h index 99bb8251f08e5..6d159c45dadff 100644 --- a/paddle/fluid/inference/tests/api/tester_helper.h +++ b/paddle/fluid/inference/tests/api/tester_helper.h @@ -44,6 +44,7 @@ DEFINE_string(int8_model, "", "INT8 model path"); DEFINE_string(infer_data, "", "data file"); DEFINE_string(refer_result, "", "reference result for comparison"); DEFINE_int32(batch_size, 1, "batch size"); +DEFINE_bool(ernie_large, false, "Test ernie large"); DEFINE_bool(with_accuracy_layer, true, "Calculate the accuracy while label is in the input"); DEFINE_bool(enable_fp32, true, "Enable FP32 type prediction"); diff --git a/paddle/fluid/operators/fc_op.cc b/paddle/fluid/operators/fc_op.cc index 484c4baef94de..93ee69d3033d7 100644 --- a/paddle/fluid/operators/fc_op.cc +++ b/paddle/fluid/operators/fc_op.cc @@ -32,17 +32,33 @@ class FCOp : public framework::OperatorWithKernel { auto in_dims = ctx->GetInputDim("Input"); auto w_dims = ctx->GetInputDim("W"); + bool padding_weights = ctx->Attrs().Get("padding_weights"); if (ctx->HasInput("Bias")) { auto bias_dims = ctx->GetInputDim("Bias"); + auto w_dims1 = padding_weights ? w_dims[1] - 4 : w_dims[1]; if (bias_dims.size() == 2) { PADDLE_ENFORCE_EQ(bias_dims[0], 1, - "The shape of Bias must be [1, dim]."); - PADDLE_ENFORCE_EQ(bias_dims[1], w_dims[1], - "The shape of Bias must be [1, dim]."); + platform::errors::InvalidArgument( + "The shape of Bias is invalid." + "The height of Bias should be 1." + "But received height of Bias is %d.", + bias_dims[0])); + PADDLE_ENFORCE_EQ( + bias_dims[1], w_dims1, + platform::errors::InvalidArgument( + "The shape of Bias is invalid." + "The width of Bias should be equal to width of Weight." + "But received width of Bias is %d and width of Weight is %d.", + bias_dims[1], w_dims1)); } else if (bias_dims.size() == 1) { - PADDLE_ENFORCE_EQ(bias_dims[0], w_dims[1], - "The shape of Bias must be [1, dim]."); + PADDLE_ENFORCE_EQ( + bias_dims[0], w_dims1, + platform::errors::InvalidArgument( + "The shape of Bias is invalid." + "The height of Bias should be equal to the width of weight." + "But received height of Bias is %d and width of Weight is %d.", + bias_dims[0], w_dims1)); } } @@ -65,7 +81,8 @@ class FCOp : public framework::OperatorWithKernel { "in_num_col_dims."); std::vector output_dims; - FCOutputSize(in_dims, w_dims, output_dims, in_num_col_dims); + FCOutputSize(in_dims, w_dims, output_dims, in_num_col_dims, + padding_weights); ctx->SetOutputDim("Out", framework::make_ddim(output_dims)); ctx->ShareLoD("Input", "Out"); @@ -138,6 +155,11 @@ class FCOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr("use_mkldnn", "(bool, default false) Only used in mkldnn kernel") .SetDefault(false); + AddAttr( + "padding_weights", + "(bool, default false) When padding weights in the fc fuse pass, " + "the 'padding_weights' attribute is set as true.") + .SetDefault(false); AddAttr(framework::kAllKernelsMustComputeRuntimeShape, "Skip calling InferShape() function in the runtime.") .SetDefault(true); diff --git a/paddle/fluid/operators/fc_op.h b/paddle/fluid/operators/fc_op.h index bf08e6ba6866e..c978c34b8fd30 100644 --- a/paddle/fluid/operators/fc_op.h +++ b/paddle/fluid/operators/fc_op.h @@ -38,17 +38,21 @@ class FCOpGrad : public framework::OperatorWithKernel { inline void FCOutputSize(const framework::DDim& in_dims, const framework::DDim& w_dims, std::vector& out_dims, // NOLINT - int in_num_col_dims) { + int in_num_col_dims, bool padding_weights) { auto in_mat_dims = framework::flatten_to_2d(in_dims, in_num_col_dims); - PADDLE_ENFORCE_EQ( - in_mat_dims[1], w_dims[0], - "Fully Connected input and weigth size do not match. %s, %s"); + auto w_dims0 = padding_weights ? w_dims[0] - 4 : w_dims[0]; + auto w_dims1 = padding_weights ? w_dims[1] - 4 : w_dims[1]; + PADDLE_ENFORCE_EQ(in_mat_dims[1], w_dims0, + platform::errors::InvalidArgument( + "Fully Connected input and weigth size do not match. " + "input width: %d,weight height: %d", + in_mat_dims[1], w_dims0)); out_dims.reserve(static_cast(in_num_col_dims + 1)); for (int i = 0; i < in_num_col_dims; ++i) { out_dims.push_back(in_dims[i]); } - out_dims.push_back(w_dims[1]); + out_dims.push_back(w_dims1); } template @@ -64,14 +68,18 @@ class FCOpKernel : public framework::OpKernel { (ctx.Attr("activation_type") == "relu") ? true : false; auto w_dims = w->dims(); + bool padding_weights = ctx.Attr("padding_weights"); std::vector output_dims; - FCOutputSize(input->dims(), w_dims, output_dims, in_num_col_dims); + FCOutputSize(input->dims(), w_dims, output_dims, in_num_col_dims, + padding_weights); output->Resize(framework::make_ddim(output_dims)); output->set_lod(input->lod()); auto out_dims = output->dims(); - int M = framework::product(out_dims) / w_dims[1]; + auto w_dims0 = padding_weights ? w_dims[0] - 4 : w_dims[0]; + auto w_dims1 = padding_weights ? w_dims[1] - 4 : w_dims[1]; + int M = framework::product(out_dims) / w_dims1; const T* input_data = input->data(); const T* w_data = w->data(); @@ -79,8 +87,8 @@ class FCOpKernel : public framework::OpKernel { auto& dev_ctx = ctx.template device_context(); math::FCFunctor fc; - fc(dev_ctx, M, w_dims[1], w_dims[0], input_data, w_data, output_data, - bias ? bias->data() : NULL, with_relu); + fc(dev_ctx, M, w_dims1, w_dims0, input_data, w_data, output_data, + bias ? bias->data() : NULL, with_relu, padding_weights); } }; diff --git a/paddle/fluid/operators/jit/more/intrinsic/layer_norm.cc b/paddle/fluid/operators/jit/more/intrinsic/layer_norm.cc index a4e3246f10495..61d8c50c56825 100644 --- a/paddle/fluid/operators/jit/more/intrinsic/layer_norm.cc +++ b/paddle/fluid/operators/jit/more/intrinsic/layer_norm.cc @@ -26,131 +26,141 @@ namespace intrinsic { void LayerNorm(float* x, float* out, float* mean, float* var, const float* scale, const float* bias, int height, const float epsilon, int right) { - __m256 sum; - __m256 mean_vec, var_vec; - __m128 hi, lo; - __m256 tmp; - size_t offset; - size_t j; int block = YMM_FLOAT_BLOCK; const int rest = right % block; const int end = right - rest; +#ifdef PADDLE_WITH_MKLML +#pragma omp parallel + { +#endif + __m256 sum; + __m256 mean_vec, var_vec; + __m128 hi, lo; + __m256 tmp; + size_t offset; + size_t j; + __m256 reverse_num_vec = + _mm256_div_ps(_mm256_set1_ps(1.0), _mm256_set1_ps(right)); + __m256 epsilon_vec = _mm256_set1_ps(epsilon); + int rest_mask = + ((-1) & (~((~0U) >> (sizeof(int) * 8 - (block - rest))))) & 0x0ff; + __m256i mask_vec = _mm256_set_epi32( + rest_mask & 0x80 ? 0xffffffff : 0, rest_mask & 0x40 ? 0xffffffff : 0, + rest_mask & 0x20 ? 0xffffffff : 0, rest_mask & 0x10 ? 0xffffffff : 0, + rest_mask & 0x8 ? 0xffffffff : 0, rest_mask & 0x4 ? 0xffffffff : 0, + rest_mask & 0x2 ? 0xffffffff : 0, rest_mask & 0x1 ? 0xffffffff : 0); - __m256 reverse_num_vec = - _mm256_div_ps(_mm256_set1_ps(1.0), _mm256_set1_ps(right)); - __m256 epsilon_vec = _mm256_set1_ps(epsilon); - int rest_mask = - ((-1) & (~((~0U) >> (sizeof(int) * 8 - (block - rest))))) & 0x0ff; - __m256i mask_vec = _mm256_set_epi32( - rest_mask & 0x80 ? 0xffffffff : 0, rest_mask & 0x40 ? 0xffffffff : 0, - rest_mask & 0x20 ? 0xffffffff : 0, rest_mask & 0x10 ? 0xffffffff : 0, - rest_mask & 0x8 ? 0xffffffff : 0, rest_mask & 0x4 ? 0xffffffff : 0, - rest_mask & 0x2 ? 0xffffffff : 0, rest_mask & 0x1 ? 0xffffffff : 0); +#ifdef PADDLE_WITH_MKLML +#pragma omp for +#endif + for (int i = 0; i < height; ++i) { + offset = i * right; - for (int i = 0; i < height; ++i) { - offset = i * right; - - /* get mean */ - sum = _mm256_setzero_ps(); - for (j = offset; j < end + offset; j += block) { - sum = _mm256_add_ps(sum, _mm256_loadu_ps((const float*)x + j)); - } - if (rest != 0) { - j = offset + right - block; - tmp = _mm256_loadu_ps((const float*)x + j); - tmp = _mm256_blendv_ps(_mm256_setzero_ps(), tmp, - *(__m256*)&mask_vec); // NOLINT - sum = _mm256_add_ps(sum, tmp); - } - hi = _mm256_extractf128_ps(sum, 1); - lo = _mm256_extractf128_ps(sum, 0); - sum = _mm256_add_ps( - sum, _mm256_insertf128_ps( - _mm256_insertf128_ps(_mm256_setzero_ps(), hi, 0), lo, 1)); - sum = _mm256_hadd_ps(sum, sum); - sum = _mm256_hadd_ps(sum, sum); - mean_vec = _mm256_mul_ps(sum, reverse_num_vec); - mean[i] = *reinterpret_cast(&mean_vec); - - /* get variance */ - sum = _mm256_setzero_ps(); - for (j = offset; j < end + offset; j += block) { - tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec); - tmp = _mm256_mul_ps(tmp, tmp); - sum = _mm256_add_ps(sum, tmp); - } - if (rest != 0) { - j = offset + right - block; - tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec); - tmp = _mm256_mul_ps(tmp, tmp); - tmp = _mm256_blendv_ps(_mm256_setzero_ps(), tmp, - *(__m256*)&mask_vec); // NOLINT - sum = _mm256_add_ps(sum, tmp); - } - hi = _mm256_extractf128_ps(sum, 1); - lo = _mm256_extractf128_ps(sum, 0); - sum = _mm256_add_ps( - sum, _mm256_insertf128_ps( - _mm256_insertf128_ps(_mm256_setzero_ps(), hi, 0), lo, 1)); - sum = _mm256_hadd_ps(sum, sum); - sum = _mm256_hadd_ps(sum, sum); - var_vec = _mm256_mul_ps(sum, reverse_num_vec); - var[i] = *reinterpret_cast(&var_vec); - - /* get x_norm and calculate output*/ - for (j = offset; j < end + offset; j += block) { - tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec); - tmp = _mm256_div_ps(tmp, - _mm256_sqrt_ps(_mm256_add_ps(var_vec, epsilon_vec))); - _mm256_storeu_ps(reinterpret_cast(out) + j, tmp); - } - if (rest != 0) { - j = offset + right - block; - tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec); - tmp = _mm256_div_ps(tmp, - _mm256_sqrt_ps(_mm256_add_ps(var_vec, epsilon_vec))); - _mm256_storeu_ps(reinterpret_cast(out) + j, tmp); - } - - if (scale) { - if (rest != 0) { - j = offset + right - block; - tmp = _mm256_loadu_ps((const float*)out + j); - } + /* get mean */ + sum = _mm256_setzero_ps(); for (j = offset; j < end + offset; j += block) { - _mm256_storeu_ps( - reinterpret_cast(out) + j, - _mm256_mul_ps(_mm256_loadu_ps((const float*)out + j), - _mm256_loadu_ps((const float*)scale + j - offset))); + sum = _mm256_add_ps(sum, _mm256_loadu_ps((const float*)x + j)); } if (rest != 0) { j = offset + right - block; - _mm256_storeu_ps( - reinterpret_cast(out) + j, - _mm256_mul_ps(tmp, - _mm256_loadu_ps((const float*)scale + j - offset))); + tmp = _mm256_loadu_ps((const float*)x + j); + tmp = _mm256_blendv_ps(_mm256_setzero_ps(), tmp, + *(__m256*)&mask_vec); // NOLINT + sum = _mm256_add_ps(sum, tmp); } - } + hi = _mm256_extractf128_ps(sum, 1); + lo = _mm256_extractf128_ps(sum, 0); + sum = _mm256_add_ps( + sum, _mm256_insertf128_ps( + _mm256_insertf128_ps(_mm256_setzero_ps(), hi, 0), lo, 1)); + sum = _mm256_hadd_ps(sum, sum); + sum = _mm256_hadd_ps(sum, sum); + mean_vec = _mm256_mul_ps(sum, reverse_num_vec); + mean[i] = *reinterpret_cast(&mean_vec); - if (bias) { + /* get variance */ + sum = _mm256_setzero_ps(); + for (j = offset; j < end + offset; j += block) { + tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec); + tmp = _mm256_mul_ps(tmp, tmp); + sum = _mm256_add_ps(sum, tmp); + } if (rest != 0) { j = offset + right - block; - tmp = _mm256_loadu_ps((const float*)out + j); + tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec); + tmp = _mm256_mul_ps(tmp, tmp); + tmp = _mm256_blendv_ps(_mm256_setzero_ps(), tmp, + *(__m256*)&mask_vec); // NOLINT + sum = _mm256_add_ps(sum, tmp); } + hi = _mm256_extractf128_ps(sum, 1); + lo = _mm256_extractf128_ps(sum, 0); + sum = _mm256_add_ps( + sum, _mm256_insertf128_ps( + _mm256_insertf128_ps(_mm256_setzero_ps(), hi, 0), lo, 1)); + sum = _mm256_hadd_ps(sum, sum); + sum = _mm256_hadd_ps(sum, sum); + var_vec = _mm256_mul_ps(sum, reverse_num_vec); + var[i] = *reinterpret_cast(&var_vec); + + /* get x_norm and calculate output*/ for (j = offset; j < end + offset; j += block) { - _mm256_storeu_ps( - reinterpret_cast(out) + j, - _mm256_add_ps(_mm256_loadu_ps((const float*)out + j), - _mm256_loadu_ps((const float*)bias + j - offset))); + tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec); + tmp = _mm256_div_ps( + tmp, _mm256_sqrt_ps(_mm256_add_ps(var_vec, epsilon_vec))); + _mm256_storeu_ps(reinterpret_cast(out) + j, tmp); } if (rest != 0) { j = offset + right - block; - _mm256_storeu_ps(reinterpret_cast(out) + j, - _mm256_add_ps(tmp, _mm256_loadu_ps((const float*)bias + - j - offset))); + tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec); + tmp = _mm256_div_ps( + tmp, _mm256_sqrt_ps(_mm256_add_ps(var_vec, epsilon_vec))); + _mm256_storeu_ps(reinterpret_cast(out) + j, tmp); + } + + if (scale) { + if (rest != 0) { + j = offset + right - block; + tmp = _mm256_loadu_ps((const float*)out + j); + } + for (j = offset; j < end + offset; j += block) { + _mm256_storeu_ps( + reinterpret_cast(out) + j, + _mm256_mul_ps(_mm256_loadu_ps((const float*)out + j), + _mm256_loadu_ps((const float*)scale + j - offset))); + } + if (rest != 0) { + j = offset + right - block; + _mm256_storeu_ps( + reinterpret_cast(out) + j, + _mm256_mul_ps(tmp, + _mm256_loadu_ps((const float*)scale + j - offset))); + } + } + + if (bias) { + if (rest != 0) { + j = offset + right - block; + tmp = _mm256_loadu_ps((const float*)out + j); + } + for (j = offset; j < end + offset; j += block) { + _mm256_storeu_ps( + reinterpret_cast(out) + j, + _mm256_add_ps(_mm256_loadu_ps((const float*)out + j), + _mm256_loadu_ps((const float*)bias + j - offset))); + } + if (rest != 0) { + j = offset + right - block; + _mm256_storeu_ps( + reinterpret_cast(out) + j, + _mm256_add_ps(tmp, + _mm256_loadu_ps((const float*)bias + j - offset))); + } } } +#ifdef PADDLE_WITH_MKLML } +#endif } bool LayerNormKernel::CanBeUsed(const int& d) const { diff --git a/paddle/fluid/operators/math/fc.cc b/paddle/fluid/operators/math/fc.cc index b5479a1b43568..9309a992f73a0 100644 --- a/paddle/fluid/operators/math/fc.cc +++ b/paddle/fluid/operators/math/fc.cc @@ -25,31 +25,56 @@ class FCFunctor { public: void operator()(const platform::CPUDeviceContext& context, const int M, const int N, const int K, const T* X, const T* W, T* Y, - const T* B = nullptr, bool relu = false) { + const T* B = nullptr, bool relu = false, + bool padding_weights = false) { auto blas = math::GetBlas(context); - blas.MatMul(M, N, K, X, W, Y); - if (B == NULL) { - return; - } - if (relu) { - auto compute = - jit::KernelFuncs, platform::CPUPlace>::Cache() - .At(N); + framework::Tensor Y1; + T* Y1_data = nullptr; + if (padding_weights) { + const int NN = N + 4; + const int KK = K + 4; + framework::Tensor X1; + T* X1_data = X1.mutable_data({M * KK}, platform::CPUPlace()); + Y1_data = Y1.mutable_data({M * (N + 4)}, platform::CPUPlace()); +#ifdef PADDLE_WITH_MKLML +#pragma omp parallel for +#endif for (int i = 0; i < M; i++) { - T* dst = Y + i * N; - compute(B, dst, dst, N); + memcpy(X1_data + i * KK, X + i * K, K * sizeof(T)); } + blas.GEMM(false, false, M, N, K, static_cast(1.0), X1_data, KK, W, NN, + static_cast(0.0), Y1_data, NN); } else { - auto compute = - jit::KernelFuncs, platform::CPUPlace>::Cache().At( - N); + blas.MatMul(M, N, K, X, W, Y); + } + if (B == NULL) { + if (padding_weights) { #ifdef PADDLE_WITH_MKLML #pragma omp parallel for #endif - for (int i = 0; i < M; i++) { - T* dst = Y + i * N; - compute(B, dst, dst, N); + for (int i = 0; i < M; i++) { + memcpy(Y + i * N, Y1_data + i * (N + 4), N * sizeof(T)); + } } + PADDLE_ENFORCE_EQ(relu, false, + platform::errors::PermissionDenied( + "When bias is NULL, relu can not be true.")); + return; + } + auto compute = + relu + ? jit::KernelFuncs, + platform::CPUPlace>::Cache() + .At(N) + : jit::KernelFuncs, platform::CPUPlace>::Cache() + .At(N); +#ifdef PADDLE_WITH_MKLML +#pragma omp parallel for +#endif + for (int i = 0; i < M; i++) { + T* dst = Y + i * N; + T* src = (padding_weights) ? Y1_data + i * (N + 4) : dst; + compute(B, src, dst, N); } } }; diff --git a/paddle/fluid/operators/math/fc.cu b/paddle/fluid/operators/math/fc.cu index 1b22b81039954..82da2dd805aef 100644 --- a/paddle/fluid/operators/math/fc.cu +++ b/paddle/fluid/operators/math/fc.cu @@ -41,7 +41,12 @@ class FCFunctor { public: void operator()(const platform::CUDADeviceContext& context, const int M, const int N, const int K, const T* X, const T* W, T* Y, - const T* B = nullptr, bool relu = false) { + const T* B = nullptr, bool relu = false, + bool padding_weights = false) { + PADDLE_ENFORCE_EQ( + padding_weights, false, + platform::errors::PermissionDenied( + "Weight padding in fc can not be used in GPU scope.")); auto blas = math::GetBlas(context); blas.GEMM(false, false, M, N, K, static_cast(1.0), X, K, W, N, static_cast(0.0), Y, N); diff --git a/paddle/fluid/operators/math/fc.h b/paddle/fluid/operators/math/fc.h index 9bef496fb9d39..02f81587c739f 100644 --- a/paddle/fluid/operators/math/fc.h +++ b/paddle/fluid/operators/math/fc.h @@ -26,7 +26,8 @@ class FCFunctor { public: void operator()(const DeviceContext& context, const int M, const int N, const int K, const T* X, const T* W, T* Y, - const T* B = nullptr, bool relu = false); + const T* B = nullptr, bool relu = false, + bool weight_pass = false); }; } // namespace math diff --git a/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc index 349dbffb386f8..dfaf47653fac5 100644 --- a/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc @@ -207,8 +207,13 @@ class FCPrimitiveFactory { void RecomputeOutputDims(const ExecutionContext& ctx, const LoDTensor* input, const Tensor* w, LoDTensor* output) { int in_num_col_dims = ctx.Attr("in_num_col_dims"); + bool padding_weights = ctx.Attr("padding_weights"); + PADDLE_ENFORCE_EQ(padding_weights, false, + platform::errors::PermissionDenied( + "Weight padding in fc can not be used in MKLDNN.")); std::vector output_dims; - FCOutputSize(input->dims(), w->dims(), output_dims, in_num_col_dims); + FCOutputSize(input->dims(), w->dims(), output_dims, in_num_col_dims, + padding_weights); output->Resize(framework::make_ddim(output_dims)); output->set_lod(input->lod()); } diff --git a/python/paddle/fluid/tests/unittests/test_fc_op.py b/python/paddle/fluid/tests/unittests/test_fc_op.py index 0da0fd0789a77..9028210b8fe9c 100644 --- a/python/paddle/fluid/tests/unittests/test_fc_op.py +++ b/python/paddle/fluid/tests/unittests/test_fc_op.py @@ -124,6 +124,13 @@ def config(self): self.matrix = MatrixGenerate(1, 64, 32, 3, 3, 1) +class TestFCOpWithPadding(TestFCOp): + def config(self): + self.with_bias = True + self.with_relu = True + self.matrix = MatrixGenerate(1, 4, 3, 128, 128, 2) + + class TestFCOpError(OpTest): def test_errors(self): with program_guard(Program(), Program()):