diff --git a/paddle/fluid/inference/api/CMakeLists.txt b/paddle/fluid/inference/api/CMakeLists.txt
index f3fc183c910b8..f15bd26c4476a 100755
--- a/paddle/fluid/inference/api/CMakeLists.txt
+++ b/paddle/fluid/inference/api/CMakeLists.txt
@@ -64,8 +64,14 @@ if(WIN32)
   target_link_libraries(paddle_inference_api phi)
 endif()
 
-set(inference_deps ${analysis_deps} paddle_inference_api analysis
-    analysis_config naive_executor ${GLOB_PASS_LIB})
+set(inference_deps
+    ${analysis_deps}
+    paddle_inference_api
+    analysis
+    analysis_config
+    naive_executor
+    ${GLOB_PASS_LIB}
+    transform)
 
 if(WITH_GPU AND TENSORRT_FOUND)
   set(inference_deps ${inference_deps} tensorrt_engine tensorrt_converter)
diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index 99b50c9b8ab28..5f90955ec7cf3 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -103,6 +103,7 @@
 #endif
 
 #include "paddle/fluid/ir_adaptor/translator/translate.h"
+#include "paddle/fluid/pir/transforms/constant_folding_pass.h"
 #include "paddle/fluid/pir/transforms/dead_code_elimination_pass.h"
 #include "paddle/fluid/pir/transforms/inplace_pass.h"
 #include "paddle/fluid/pir/transforms/pd_op_to_kernel_pass.h"
@@ -731,10 +732,11 @@ bool AnalysisPredictor::PrepareExecutor() {
         paddle::TranslateLegacyProgramToProgram(*inference_program_));
 
     ::pir::PassManager pm(::pir::IrContext::Instance(), 2);
+    pm.AddPass(::pir::CreateConstantFoldingPass(place_, sub_scope_));
     pm.AddPass(::pir::CreateReplaceFetchWithShadowOutputPass());
     pm.AddPass(::pir::CreateDeadCodeEliminationPass());
-    pm.EnableIRPrinting();
+    // pm.EnableIRPrinting();
     pm.Run(pir_program_.get());
 
     pir_program_ = std::move(
diff --git a/paddle/fluid/pir/transforms/constant_folding_pass.cc b/paddle/fluid/pir/transforms/constant_folding_pass.cc
index 4e36f1df9defa..39daebc1a3b8f 100644
--- a/paddle/fluid/pir/transforms/constant_folding_pass.cc
+++ b/paddle/fluid/pir/transforms/constant_folding_pass.cc
@@ -17,20 +17,21 @@
 #include <memory>
 #include <string>
 #include <vector>
-
-// NOTE(zhangbo9674): File pd_op.h is generated by op_gen.py, see details in
-// paddle/fluid/pir/dialect/CMakeLists.txt.
-#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include #include "paddle/fluid/framework/new_executor/interpretercore.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" #include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/fluid/pir/transforms/pd_op_to_kernel_pass.h" #include "paddle/fluid/pir/transforms/transform_general_functions.h" + #include "paddle/phi/common/place.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/enforce.h" + +#include "paddle/pir/core/builtin_attribute.h" #include "paddle/pir/core/builtin_op.h" #include "paddle/pir/core/ir_context.h" #include "paddle/pir/core/op_result.h" @@ -46,115 +47,101 @@ namespace { class ConstantFoldingPattern : public pir::RewritePattern { public: - ConstantFoldingPattern(pir::IrContext* context, - paddle::framework::Scope* scope, - pir::PatternBenefit benefit = 1, - const std::vector& generated_names = {}) - : RewritePattern(MatchAnyOpTypeTag(), benefit, context, generated_names), - scope_(scope) {} + ConstantFoldingPattern( + pir::IrContext* context, + size_t* suffix, + const phi::Place& place, + paddle::framework::Scope* scope, + paddle::framework::interpreter::ExecutionConfig* exe_config, + std::vector* deleted_vars) + : RewritePattern(MatchAnyOpTypeTag(), + 1 /*benefit*/, + context, + {} /*generated_names*/), + counter_(suffix), + place_(place), + scope_(scope), + exe_config_(exe_config), + deleted_vars_(deleted_vars) { + exe_config_->create_local_scope = false; + } bool Match(pir::Operation* op) const override { - // TODO(liuyuanle): Use trait to improve robustness. if (op->isa() || op->isa() || - op->isa() || - op->isa()) + op->isa() || op->isa() || + op->isa()) return false; - // Inputs must come from get parameter op. - for (uint32_t i = 0; i < op->num_operands(); ++i) - if (!pir::GetDefiningOpForInput(op, i)->isa()) - return false; + if (!ValidOp(op)) { + return false; + } + return true; } void Rewrite(pir::Operation* op, pir::PatternRewriter& rewriter) const override { // NOLINT - pir::Program* program = op->GetParentProgram(); - auto temp_program = BuildProgramFromOperation(op); - - std::vector fetch_var_names; - auto block = temp_program->block(); - for (auto it = block->begin(); it != block->end(); ++it) { - if ((*it)->isa()) { - size_t index = (*it) - ->attributes() - .at("col") - .dyn_cast() - .data(); - - if (fetch_var_names.size() < index + 1) { - fetch_var_names.resize(index + 1); - } - - fetch_var_names[index] = (*it) - ->attributes() - .at("name") - .dyn_cast() - .AsString() + - "@fetch"; - } - } + VLOG(4) << "constant_folding_pass applys on [" << op->name() << "] op"; + pir::Program new_program(ir_context()); + auto output_var_name = BuildProgramFromOperation(op, &new_program); - // Execute program - exe_config_.create_local_scope = false; + // execute program + exe_config_->skip_gc_vars.insert(output_var_name); auto kernel_program = - paddle::dialect::PdOpLowerToKernelPass(temp_program.get()); - paddle::framework::InterpreterCore core(phi::CPUPlace{}, - fetch_var_names, - kernel_program->block(), - scope_, - exe_config_); - - paddle::framework::FetchList fetch_list = core.Run({}); - - // TODO(liuyuanle): Support multiple output. 
-    auto out_tensor = PADDLE_GET_CONST(phi::DenseTensor, fetch_list[0]);
-    std::unique_ptr<pir::Parameter> parameter =
-        std::make_unique<pir::Parameter>(
-            reinterpret_cast<void*>(out_tensor.data()),
-            out_tensor.numel() * phi::SizeOf(out_tensor.dtype()),
-            op->result(0).type());
-
-    std::string param_name =
-        "@constant_folding_pass@_" + std::to_string(suffix_++);
-    exe_config_.skip_gc_vars.insert(param_name);
-
-    auto* param_var = scope_->Var(param_name);
-    auto* param_tensor = param_var->GetMutable<phi::DenseTensor>();
-    *param_tensor = out_tensor;
-    program->SetParameter(param_name, std::move(parameter));
-
-    // rewriter.SetInsertionPoint(op);
-    auto get_parameter_op =
-        rewriter.Build<pir::GetParameterOp>(param_name, op->result(0).type());
+        paddle::dialect::PdOpLowerToKernelPass(&new_program, place_);
+    paddle::framework::InterpreterCore core(
+        place_, {}, kernel_program->block(), scope_, *exe_config_);
+    core.Run({});
+
+    // TODO(liuyuanle): support multiple output
+    auto get_parameter_op = rewriter.Build<pir::GetParameterOp>(
+        output_var_name, op->result(0).type());
+    get_parameter_op->set_attribute(
+        kAttrIsPersisable, rewriter.array_attr({rewriter.bool_attr(true)}));
+
+    VLOG(4) << "constant_folding_pass applied on [" << op->name() << "] op";
 
     rewriter.ReplaceAllUsesWith(op->result(0), get_parameter_op->result(0));
     rewriter.EraseOp(op);
   }
 
  private:
-  std::unique_ptr<pir::Program> BuildProgramFromOperation(
-      pir::Operation* op) const {
-    auto program = std::make_unique<pir::Program>(ir_context());
-    pir::Builder builder = pir::Builder(ir_context(), program->block());
+  bool ValidOp(pir::Operation* op) const {
+    // 1. inputs must come from get_parameter op
+    // 2. inputs must be a dense tensor type
+    for (uint32_t i = 0; i < op->num_operands(); i++) {
+      if (!pir::GetDefiningOpForInput(op, i)->isa<pir::GetParameterOp>() ||
+          !op->operand_source(i)
+               .type()
+               .isa<paddle::dialect::DenseTensorType>()) {
+        return false;
+      }
+    }
+    // 3. outputs must be a dense tensor type
+    for (uint32_t i = 0; i < op->num_results(); i++) {
+      if (!op->result(i).type().isa<paddle::dialect::DenseTensorType>()) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  std::string BuildProgramFromOperation(pir::Operation* op,
+                                        pir::Program* new_program) const {
+    pir::Builder builder = pir::Builder(ir_context(), new_program->block());
 
     // prepare op inputs
     std::vector<pir::Value> op_inputs;
     for (uint32_t i = 0; i < op->num_operands(); i++) {
-      PADDLE_ENFORCE_EQ(
-          op->operand_source(i).type().isa<paddle::dialect::DenseTensorType>(),
-          true,
-          phi::errors::InvalidArgument(
-              "Op's input must be a dense tensor type."));
-
-      auto [param_name, param] =
-          pir::GetParameterFromValue(op->operand_source(i));
-      program->SetParameter(param_name,
-                            std::make_unique<pir::Parameter>(*param));
-
+      const auto& param_name =
+          pir::GetParameterNameFromValue(op->operand_source(i));
       auto* param_var = scope_->FindVar(param_name);
       PADDLE_ENFORCE_NOT_NULL(
          param_var,
          phi::errors::InvalidArgument("Parameter var not in scope."));
+      // inputs consumed only by this op can have their backing vars erased
+      // from the scope once the op is folded
+      if (op->operand_source(i).use_count() == 1) {
+        deleted_vars_->push_back(param_name);
+      }
 
       auto get_parameter_op = builder.Build<pir::GetParameterOp>(
           param_name, op->operand_source(i).type());
@@ -170,60 +157,80 @@ class ConstantFoldingPattern : public pir::RewritePattern {
     auto* temp_op =
         builder.Build(op_inputs, op->attributes(), output_types, op->info());
 
-    // TODO(liuyuanle): Support multiple output.
+    // TODO(liuyuanle): support multiple output
     // for (uint32_t i = 0; i < op->num_results(); i++) {
-    PADDLE_ENFORCE_EQ(
-        temp_op->result(0).type().isa<paddle::dialect::DenseTensorType>(),
-        true,
-        phi::errors::InvalidArgument(
-            "Op's output must be a dense tensor type."));
-
-    builder.Build<paddle::dialect::FetchOp>(
-        temp_op->result(0), "fetch_" + std::to_string(suffix_++), 0);
+
+    // name the folded output with a timestamp plus counter so repeated runs
+    // cannot collide on variable names
+    std::stringstream ss;
+    ss << std::chrono::high_resolution_clock::now().time_since_epoch().count();
+    std::string output_var_name =
+        "constant_folding@_" + ss.str() + std::to_string((*counter_)++);
+
+    builder.Build<pir::ShadowOutputOp>(temp_op->result(0), output_var_name);
     // }
 
-    return program;
+    return output_var_name;
   }
 
  private:
+  size_t* counter_{nullptr};
+  phi::Place place_;
   paddle::framework::Scope* scope_{nullptr};
-  inline static size_t suffix_{0};
-  inline static paddle::framework::interpreter::ExecutionConfig exe_config_{};
+  paddle::framework::interpreter::ExecutionConfig* exe_config_{nullptr};
+  std::vector<std::string>* deleted_vars_{nullptr};
 };
 
 class ConstantFoldingPass : public pir::Pass {
  public:
-  ConstantFoldingPass() : pir::Pass("constant_folding_pass", 1) {}
+  ConstantFoldingPass(const phi::Place& place, paddle::framework::Scope* scope)
+      : pir::Pass("constant_folding_pass", 1), place_(place), scope_(scope) {
+    PADDLE_ENFORCE_NOT_NULL(
+        scope_, phi::errors::InvalidArgument("scope cannot be nullptr"));
+  }
 
+ private:
   bool Initialize(pir::IrContext* context) override {
     pir::RewritePatternSet ps(context);
-    ps.Add<ConstantFoldingPattern>(context, &scope_);
+    ps.Add<ConstantFoldingPattern>(
+        context, &counter_, place_, scope_, &exe_config_, &deleted_vars_);
     patterns_ = pir::FrozenRewritePatternSet(std::move(ps));
     return true;
   }
 
   void Run(pir::Operation* op) override {
+    size_t op_nums = op->GetParentProgram()->block()->size();
     pir::GreedyRewriteConfig cfg;
     cfg.use_top_down_traversal = true;
     cfg.max_iterations = 10;
     pir::ApplyPatternsGreedily(op->region(0), patterns_, cfg);
+
+    // delete the parameter vars whose only consumers were folded away
+    scope_->EraseVars(deleted_vars_);
+    LOG(INFO) << " ------ constant_folding_pass done: [" << counter_ << "/"
+              << op_nums << "]";
   }
 
   bool CanApplyOn(pir::Operation* op) const override {
+    // TODO(liuyuanle): remove op->isa<::pir::ModuleOp>()
     return op->isa<::pir::ModuleOp>() && op->num_regions() > 0;
   }
 
  private:
+  size_t counter_{0};
+  phi::Place place_;
+  paddle::framework::Scope* scope_{nullptr};
+  paddle::framework::interpreter::ExecutionConfig exe_config_{};
+  std::vector<std::string> deleted_vars_;
 
   pir::FrozenRewritePatternSet patterns_;
-  paddle::framework::Scope scope_;
 };
 
 }  // namespace
 
 namespace pir {
 
-std::unique_ptr<Pass> CreateConstantFoldingPass() {
-  return std::make_unique<ConstantFoldingPass>();
+std::unique_ptr<Pass> CreateConstantFoldingPass(
+    const phi::Place& place, paddle::framework::Scope* scope) {
+  return std::make_unique<ConstantFoldingPass>(place, scope);
 }
 
 }  // namespace pir
diff --git a/paddle/fluid/pir/transforms/constant_folding_pass.h b/paddle/fluid/pir/transforms/constant_folding_pass.h
index b49c9d90493b1..0939ee589d448 100644
--- a/paddle/fluid/pir/transforms/constant_folding_pass.h
+++ b/paddle/fluid/pir/transforms/constant_folding_pass.h
@@ -15,12 +15,20 @@
 #pragma once
 
 #include <memory>
+#include "paddle/phi/common/place.h"
 #include "paddle/pir/core/dll_decl.h"
 
+namespace paddle {
+namespace framework {
+class Scope;
+}
+}  // namespace paddle
+
 namespace pir {
 
 class Pass;
 
-IR_API std::unique_ptr<Pass> CreateConstantFoldingPass();
+IR_API std::unique_ptr<Pass> CreateConstantFoldingPass(
+    const phi::Place& place, paddle::framework::Scope* scope);
 
 }  // namespace pir
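Note: the following usage sketch is not part of the patch. It mirrors the construction shown in analysis_predictor.cc above and in pattern_rewrite_test.cc below. The pass no longer owns its scope: the caller supplies the place and a scope, and the scope must outlive the pass manager run.

    // Minimal sketch of the new two-argument API; `program` is assumed to be
    // a pir::Program* the caller has already built or translated.
    paddle::framework::Scope scope;
    pir::PassManager pm(pir::IrContext::Instance());
    pm.AddPass(pir::CreateConstantFoldingPass(phi::CPUPlace{}, &scope));
    pm.AddPass(pir::CreateDeadCodeEliminationPass());
    pm.Run(program);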
diff --git a/paddle/fluid/pir/transforms/dead_code_elimination_pass.cc b/paddle/fluid/pir/transforms/dead_code_elimination_pass.cc
index 7535ddeb513db..9c6fcd9b3d9ca 100644
--- a/paddle/fluid/pir/transforms/dead_code_elimination_pass.cc
+++ b/paddle/fluid/pir/transforms/dead_code_elimination_pass.cc
@@ -17,6 +17,7 @@
 #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
 #include "paddle/pir/core/builtin_op.h"
 #include "paddle/pir/core/program.h"
+#include "paddle/pir/dialect/control_flow/ir/cf_op.h"
 #include "paddle/pir/pass/pass.h"
 #include "paddle/pir/pass/pass_registry.h"
 #include "paddle/pir/pattern_rewrite/frozen_rewrite_pattern_set.h"
@@ -35,16 +36,16 @@ class DeadCodeEliminationPattern : public pir::RewritePattern {
   }
 
   bool Match(pir::Operation* op) const override {
-    if (op->isa<pir::GetParameterOp>() ||
-        op->isa<paddle::dialect::FetchOp>())
+    if (op->isa<pir::GetParameterOp>() || op->isa<paddle::dialect::FetchOp>() ||
+        op->isa<pir::ShadowOutputOp>()) {
       return false;
-
+    }
     return op->use_empty();
   }
 
   void Rewrite(pir::Operation* op,
                pir::PatternRewriter& rewriter) const override {  // NOLINT
-    if (op->dyn_cast<pir::GetParameterOp>()) {
+    if (op->isa<pir::GetParameterOp>()) {
       // Delete parameter from program.
       pir::GetParameterOp get_parameter_op =
           op->dyn_cast<pir::GetParameterOp>();
diff --git a/paddle/fluid/pir/transforms/transform_general_functions.cc b/paddle/fluid/pir/transforms/transform_general_functions.cc
index 8bd5028688b13..1d7c226197668 100644
--- a/paddle/fluid/pir/transforms/transform_general_functions.cc
+++ b/paddle/fluid/pir/transforms/transform_general_functions.cc
@@ -22,8 +22,7 @@
 
 namespace pir {
 
-std::pair<std::string, pir::Parameter*> GetParameterFromValue(
-    pir::Value value) {
+std::string GetParameterNameFromValue(pir::Value value) {
   pir::GetParameterOp op =
       value.dyn_cast<pir::OpResult>().owner()->dyn_cast<pir::GetParameterOp>();
   PADDLE_ENFORCE_NOT_NULL(
@@ -37,10 +36,7 @@ std::pair<std::string, pir::Parameter*> GetParameterFromValue(
           .at(op.attributes_name[0])
           .dyn_cast<pir::StrAttribute>()
           .AsString();
-  pir::Parameter* param = program->GetParameter(name);
-  PADDLE_ENFORCE_NOT_NULL(
-      param, phi::errors::InvalidArgument("Parameter should not be null."));
-  return {name, param};
+  return name;
 }
 
 const phi::DDim& GetShapeFromValue(pir::Value value) {
diff --git a/paddle/fluid/pir/transforms/transform_general_functions.h b/paddle/fluid/pir/transforms/transform_general_functions.h
index 77c790235b832..0d35ff776ce8c 100644
--- a/paddle/fluid/pir/transforms/transform_general_functions.h
+++ b/paddle/fluid/pir/transforms/transform_general_functions.h
@@ -25,16 +25,16 @@ namespace pir {
 
 /**
- * @brief Get the [name, parameter] pair of pararmeter from a value.
+ * @brief Get the name of the parameter from a value.
  *
  * @note The value must be a output of a GetParameterOp.
  *
  * @param pir::Value
  *
- * @return std::pair<std::string, pir::Parameter*>
+ * @return std::string
  */
-std::pair<std::string, pir::Parameter*> GetParameterFromValue(pir::Value value);
+std::string GetParameterNameFromValue(pir::Value value);
 
 /**
  * @brief Get tensor's shape from a value.
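Note: a sketch, not part of the patch. With GetParameterFromValue reduced to GetParameterNameFromValue, callers that previously received the pir::Parameter* now resolve the variable through the scope themselves, as the constant folding pass above does:

    // `value` is assumed to be the output of a GetParameterOp, and `scope`
    // the paddle::framework::Scope handed to the pass.
    const std::string name = pir::GetParameterNameFromValue(value);
    auto* var = scope->FindVar(name);
    const auto& tensor = var->Get<phi::DenseTensor>();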
diff --git a/test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc b/test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc
index d45a74f6fd0d1..df112a6a3f44c 100644
--- a/test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc
+++ b/test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc
@@ -20,11 +20,15 @@
 #include <memory>
 #include <vector>
 
+#include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
+#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
+#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
+#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
 #include "paddle/fluid/pir/transforms/constant_folding_pass.h"
 #include "paddle/fluid/pir/transforms/dead_code_elimination_pass.h"
 #include "paddle/fluid/pir/transforms/transform_general_functions.h"
-#include "paddle/phi/core/kernel_registry.h"
+
 #include "paddle/pir/core/builder.h"
 #include "paddle/pir/core/builtin_attribute.h"
 #include "paddle/pir/core/builtin_dialect.h"
@@ -44,13 +48,9 @@
 #include "paddle/pir/pattern_rewrite/pattern_match.h"
 #include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h"
 
-// NOTE(zhangbo9674): File pd_op.h is generated by op_gen.py, see details in
-// paddle/fluid/pir/dialect/CMakeLists.txt.
-#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
-
-#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
-#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
+#include "paddle/phi/common/place.h"
 #include "paddle/phi/core/ddim.h"
+#include "paddle/phi/core/kernel_registry.h"
 
 // build Conv2dFusionOp
 #include "paddle/fluid/pir/dialect/operator/interface/infermeta.h"
@@ -1120,21 +1120,22 @@ TEST(pattern_rewrite, Patterns) {
   BuildProgram(builder);
 
   EXPECT_EQ(program.block()->size(), 11u);
-
+  paddle::framework::Scope scope;
   pir::PassManager pm(ctx);
   pm.AddPass(std::make_unique<TestPass>());
-  pm.AddPass(pir::CreateConstantFoldingPass());
+  pm.AddPass(pir::CreateConstantFoldingPass(phi::CPUPlace{}, &scope));
   pm.AddPass(pir::CreateDeadCodeEliminationPass());
   pm.EnablePassTiming();
-  pm.EnableIRPrinting(std::make_unique<pir::PassManager::IRPrinterOption>(
-      [](pir::Pass *pass, pir::Operation *op) {
-        return pass->name() == "constant_folding_pass";
-      },
-      [](pir::Pass *pass, pir::Operation *op) {
-        return pass->name() == "constant_folding_pass";
-      },
-      true,
-      true));
+  pm.EnableIRPrinting();
+  // pm.EnableIRPrinting(std::make_unique<pir::PassManager::IRPrinterOption>(
+  //     [](pir::Pass *pass, pir::Operation *op) {
+  //       return pass->name() == "constant_folding_pass";
+  //     },
+  //     [](pir::Pass *pass, pir::Operation *op) {
+  //       return pass->name() == "constant_folding_pass";
+  //     },
+  //     true,
+  //     true));
 
   CHECK_EQ(pm.Run(&program), true);
   EXPECT_EQ(program.block()->size(), 2u);
diff --git a/test/ir/inference/test_inference_predictor_run.py b/test/ir/inference/test_inference_predictor_run.py
index 1c552bc82b77e..1d8abc174f1cf 100644
--- a/test/ir/inference/test_inference_predictor_run.py
+++ b/test/ir/inference/test_inference_predictor_run.py
@@ -62,8 +62,10 @@ def setUp(self):
     def tearDown(self):
         self.temp_dir.cleanup()
 
+    def enable_pir(self, flag: bool):
+        paddle.set_flags({'FLAGS_enable_pir_in_executor': flag})
+
     def init_predictor(self):
-        paddle.set_flags({'FLAGS_enable_pir_in_executor': True})
         config = Config(
             os.path.join(
                 self.temp_dir.name,
@@ -115,12 +117,15 @@ def get_inorder_output(self, predictor):
         return outputs[0]
 
     def test_output(self):
+        self.enable_pir(False)
         predictor = self.init_predictor()
-        inorder_output = self.get_inorder_output(predictor)
-        disorder_output = self.get_disorder_output(predictor)
+        output = self.get_inorder_output(predictor)
+        self.enable_pir(True)
+        pir_predictor = self.init_predictor()
+        pir_output = self.get_disorder_output(pir_predictor)
 
         np.testing.assert_allclose(
-            inorder_output.numpy().flatten(), disorder_output.numpy().flatten()
+            output.numpy().flatten(), pir_output.numpy().flatten()
         )