Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Inference] Support constant_folding_pass on PIR #58753

Merged
merged 9 commits into the base branch on Nov 9, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
14 changes: 12 additions & 2 deletions paddle/fluid/inference/api/CMakeLists.txt
Expand Up @@ -64,8 +64,18 @@ if(WIN32)
target_link_libraries(paddle_inference_api phi)
endif()

set(inference_deps ${analysis_deps} paddle_inference_api analysis
analysis_config naive_executor ${GLOB_PASS_LIB})
set(PIR_PASS_DEPS
pd_constant_folding_pass dead_code_elimination_pass pd_op_to_kernel_pass
pd_inplace_pass replace_fetch_with_shadow_output_pass)

set(inference_deps
${analysis_deps}
paddle_inference_api
analysis
analysis_config
naive_executor
${GLOB_PASS_LIB}
${PIR_PASS_DEPS})

if(WITH_GPU AND TENSORRT_FOUND)
set(inference_deps ${inference_deps} tensorrt_engine tensorrt_converter)
Expand Down
4 changes: 3 additions & 1 deletion paddle/fluid/inference/api/analysis_predictor.cc
Expand Up @@ -103,6 +103,7 @@
#endif

#include "paddle/fluid/ir_adaptor/translator/translate.h"
#include "paddle/fluid/pir/transforms/constant_folding_pass.h"
#include "paddle/fluid/pir/transforms/dead_code_elimination_pass.h"
#include "paddle/fluid/pir/transforms/inplace_pass.h"
#include "paddle/fluid/pir/transforms/pd_op_to_kernel_pass.h"
Expand Down Expand Up @@ -731,10 +732,11 @@ bool AnalysisPredictor::PrepareExecutor() {
paddle::TranslateLegacyProgramToProgram(*inference_program_));

::pir::PassManager pm(::pir::IrContext::Instance(), 2);
pm.AddPass(::pir::CreateConstantFoldingPass(place_, sub_scope_));
pm.AddPass(::pir::CreateReplaceFetchWithShadowOutputPass());
pm.AddPass(::pir::CreateDeadCodeEliminationPass());

pm.EnableIRPrinting();
// pm.EnableIRPrinting();
pm.Run(pir_program_.get());

pir_program_ = std::move(
Expand Down
172 changes: 83 additions & 89 deletions paddle/fluid/pir/transforms/constant_folding_pass.cc
Expand Up @@ -17,20 +17,21 @@
#include <memory>
#include <string>
#include <unordered_map>

// NOTE(zhangbo9674): File pd_op.h is generated by op_gen.py, see details in
// paddle/fluid/pir/dialect/CMakeLists.txt.
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include <vector>

#include "paddle/fluid/framework/new_executor/interpretercore.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/transforms/pd_op_to_kernel_pass.h"
#include "paddle/fluid/pir/transforms/transform_general_functions.h"

#include "paddle/phi/common/place.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"

#include "paddle/pir/core/builtin_attribute.h"
#include "paddle/pir/core/builtin_op.h"
#include "paddle/pir/core/ir_context.h"
#include "paddle/pir/core/op_result.h"
Expand All @@ -46,21 +47,32 @@ namespace {

class ConstantFoldingPattern : public pir::RewritePattern {
public:
ConstantFoldingPattern(pir::IrContext* context,
paddle::framework::Scope* scope,
pir::PatternBenefit benefit = 1,
const std::vector<std::string>& generated_names = {})
: RewritePattern(MatchAnyOpTypeTag(), benefit, context, generated_names),
scope_(scope) {}
ConstantFoldingPattern(
pir::IrContext* context,
size_t* suffix,
const phi::Place& place,
paddle::framework::Scope* scope,
paddle::framework::interpreter::ExecutionConfig* exe_config,
std::vector<std::string>* deleted_vars)
: RewritePattern(MatchAnyOpTypeTag(),
1 /*benefit*/,
context,
{} /*generated_names*/),
counter_(suffix),
place_(place),
scope_(scope),
exe_config_(exe_config),
deleted_vars_(deleted_vars) {
exe_config_->create_local_scope = false;
}

bool Match(pir::Operation* op) const override {
// TODO(liuyuanle): Use trait to improve robustness.
if (op->isa<pir::GetParameterOp>() || op->isa<pir::SetParameterOp>() ||
op->isa<paddle::dialect::FetchOp>() ||
op->isa<paddle::dialect::ShadowOutputOp>())
op->isa<pir::ShadowOutputOp>() || op->isa<paddle::dialect::FetchOp>() ||
op->isa<paddle::dialect::FeedOp>())
return false;

// Inputs must come from get parameter op.
// inputs must come from get parameter op
for (uint32_t i = 0; i < op->num_operands(); ++i)
if (!pir::GetDefiningOpForInput(op, i)->isa<pir::GetParameterOp>())
return false;
Expand All @@ -69,73 +81,36 @@ class ConstantFoldingPattern : public pir::RewritePattern {

void Rewrite(pir::Operation* op,
pir::PatternRewriter& rewriter) const override { // NOLINT
pir::Program* program = op->GetParentProgram();
auto temp_program = BuildProgramFromOperation(op);

std::vector<std::string> fetch_var_names;
auto block = temp_program->block();
for (auto it = block->begin(); it != block->end(); ++it) {
if ((*it)->isa<paddle::dialect::FetchOp>()) {
size_t index = (*it)
->attributes()
.at("col")
.dyn_cast<pir::Int32Attribute>()
.data();

if (fetch_var_names.size() < index + 1) {
fetch_var_names.resize(index + 1);
}

fetch_var_names[index] = (*it)
->attributes()
.at("name")
.dyn_cast<pir::StrAttribute>()
.AsString() +
"@fetch";
}
}
VLOG(4) << "constant_folding_pass applys on [" << op->name() << "] op";
pir::Program new_program(ir_context());
auto output_var_name = BuildProgramFromOperation(op, &new_program);

// Execute program
exe_config_.create_local_scope = false;
// execute program
exe_config_->skip_gc_vars.insert(output_var_name);
auto kernel_program =
paddle::dialect::PdOpLowerToKernelPass(temp_program.get());
paddle::framework::InterpreterCore core(phi::CPUPlace{},
fetch_var_names,
kernel_program->block(),
scope_,
exe_config_);

paddle::framework::FetchList fetch_list = core.Run({});

// TODO(liuyuanle): Support multiple output.
auto out_tensor = PADDLE_GET_CONST(phi::DenseTensor, fetch_list[0]);
std::unique_ptr<pir::Parameter> parameter =
std::make_unique<pir::Parameter>(
reinterpret_cast<void*>(out_tensor.data()),
out_tensor.numel() * phi::SizeOf(out_tensor.dtype()),
op->result(0).type());

std::string param_name =
"@constant_folding_pass@_" + std::to_string(suffix_++);
exe_config_.skip_gc_vars.insert(param_name);

auto* param_var = scope_->Var(param_name);
auto* param_tensor = param_var->GetMutable<phi::DenseTensor>();
*param_tensor = out_tensor;
program->SetParameter(param_name, std::move(parameter));
// rewriter.SetInsertionPoint(op);
auto get_parameter_op =
rewriter.Build<pir::GetParameterOp>(param_name, op->result(0).type());
paddle::dialect::PdOpLowerToKernelPass(&new_program, place_);
paddle::framework::InterpreterCore core(
place_, {}, kernel_program->block(), scope_, *exe_config_);

core.Run({});

// TODO(liuyuanle): support multiple output
auto get_parameter_op = rewriter.Build<pir::GetParameterOp>(
output_var_name, op->result(0).type());
get_parameter_op->set_attribute(
kAttrIsPersisable, rewriter.array_attr({rewriter.bool_attr(true)}));

VLOG(4) << "constant_folding_pass applied on [" << op->name() << "] op";
rewriter.ReplaceAllUsesWith(op->result(0), get_parameter_op->result(0));
rewriter.EraseOp(op);
}

private:
std::unique_ptr<pir::Program> BuildProgramFromOperation(
pir::Operation* op) const {
auto program = std::make_unique<pir::Program>(ir_context());
pir::Builder builder = pir::Builder(ir_context(), program->block());
std::string BuildProgramFromOperation(pir::Operation* op,
pir::Program* new_program) const {
pir::Builder builder = pir::Builder(ir_context(), new_program->block());
std::string output_var_name =
"constant_folding@" + std::to_string((*counter_)++);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

用counter做唯一标识输出var name感觉不是很安全

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

下面还有操作,拼接了被折叠掉的name
image

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

提醒了我,这里有个潜在问题,随着折叠的深度,name拼接后会越来越长,我修一下


// prepare op inputs
std::vector<pir::Value> op_inputs;
Expand All @@ -146,15 +121,14 @@ class ConstantFoldingPattern : public pir::RewritePattern {
phi::errors::InvalidArgument(
"Op's input must be a dense tensor type."));

auto [param_name, param] =
pir::GetParameterFromValue(op->operand_source(i));
program->SetParameter(param_name,
std::make_unique<pir::Parameter>(*param));

const auto& param_name =
pir::GetParameterNameFromValue(op->operand_source(i));
auto* param_var = scope_->FindVar(param_name);
PADDLE_ENFORCE_NOT_NULL(
param_var,
phi::errors::InvalidArgument("Parameter var not in scope."));
output_var_name = output_var_name + "_" + param_name;
deleted_vars_->push_back(param_name);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里需不需要判断下param有没有被其他op使用?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, thanks!


auto get_parameter_op = builder.Build<pir::GetParameterOp>(
param_name, op->operand_source(i).type());
Expand All @@ -170,60 +144,80 @@ class ConstantFoldingPattern : public pir::RewritePattern {
auto* temp_op =
builder.Build(op_inputs, op->attributes(), output_types, op->info());

// TODO(liuyuanle): Support multiple output.
// TODO(liuyuanle): support multiple output
// for (uint32_t i = 0; i < op->num_results(); i++) {
PADDLE_ENFORCE_EQ(
temp_op->result(0).type().isa<paddle::dialect::DenseTensorType>(),
true,
phi::errors::InvalidArgument(
"Op's output must be a dense tensor type."));

builder.Build<paddle::dialect::FetchOp>(
temp_op->result(0), "fetch_" + std::to_string(suffix_++), 0);
builder.Build<pir::ShadowOutputOp>(temp_op->result(0), output_var_name);
// }

return program;
return output_var_name;
}

private:
size_t* counter_{nullptr};
phi::Place place_;
paddle::framework::Scope* scope_{nullptr};
inline static size_t suffix_{0};
inline static paddle::framework::interpreter::ExecutionConfig exe_config_{};
paddle::framework::interpreter::ExecutionConfig* exe_config_{nullptr};
std::vector<std::string>* deleted_vars_{nullptr};
};

class ConstantFoldingPass : public pir::Pass {
public:
ConstantFoldingPass() : pir::Pass("constant_folding_pass", 1) {}
ConstantFoldingPass(const phi::Place& place, paddle::framework::Scope* scope)
: pir::Pass("constant_folding_pass", 1), place_(place), scope_(scope) {
PADDLE_ENFORCE_NOT_NULL(
scope_, phi::errors::InvalidArgument("scope can not be nullptr"));
}

private:
bool Initialize(pir::IrContext* context) override {
pir::RewritePatternSet ps(context);
ps.Add<ConstantFoldingPattern>(context, &scope_);
ps.Add<ConstantFoldingPattern>(
context, &counter_, place_, scope_, &exe_config_, &deleted_vars_);
patterns_ = pir::FrozenRewritePatternSet(std::move(ps));
return true;
}

void Run(pir::Operation* op) override {
size_t op_nums = op->GetParentProgram()->block()->size();
pir::GreedyRewriteConfig cfg;
cfg.use_top_down_traversal = true;
cfg.max_iterations = 10;
pir::ApplyPatternsGreedily(op->region(0), patterns_, cfg);

// delete old parameter var
scope_->EraseVars(deleted_vars_);
LOG(INFO) << " ------ constant_folding_pass done: [" << counter_ << "/"
<< op_nums << "]";
}

bool CanApplyOn(pir::Operation* op) const override {
// TODO(liuyuanle): remove op->isa<::pir::ModuleOp>()
return op->isa<::pir::ModuleOp>() && op->num_regions() > 0;
}

private:
size_t counter_{0};
phi::Place place_;
paddle::framework::Scope* scope_{nullptr};
paddle::framework::interpreter::ExecutionConfig exe_config_{};
std::vector<std::string> deleted_vars_;

pir::FrozenRewritePatternSet patterns_;
paddle::framework::Scope scope_;
};

} // namespace

namespace pir {

std::unique_ptr<Pass> CreateConstantFoldingPass() {
return std::make_unique<ConstantFoldingPass>();
std::unique_ptr<Pass> CreateConstantFoldingPass(
const phi::Place& place, paddle::framework::Scope* scope) {
return std::make_unique<ConstantFoldingPass>(place, scope);
}

} // namespace pir
10 changes: 9 additions & 1 deletion paddle/fluid/pir/transforms/constant_folding_pass.h
Expand Up @@ -15,12 +15,20 @@
#pragma once

#include <memory>
#include "paddle/phi/common/place.h"
#include "paddle/pir/core/dll_decl.h"

namespace paddle {
namespace framework {
class Scope;
}
} // namespace paddle

namespace pir {

class Pass;

IR_API std::unique_ptr<Pass> CreateConstantFoldingPass();
IR_API std::unique_ptr<Pass> CreateConstantFoldingPass(
const phi::Place& place, paddle::framework::Scope* scope);

} // namespace pir
7 changes: 3 additions & 4 deletions paddle/fluid/pir/transforms/dead_code_elimination_pass.cc
Expand Up @@ -35,16 +35,15 @@ class DeadCodeEliminationPattern : public pir::RewritePattern {
}

bool Match(pir::Operation* op) const override {
if (op->isa<paddle::dialect::FetchOp>() ||
op->isa<paddle::dialect::ShadowOutputOp>())
if (op->isa<paddle::dialect::FetchOp>() || op->isa<pir::ShadowOutputOp>()) {
return false;

}
return op->use_empty();
}

void Rewrite(pir::Operation* op,
pir::PatternRewriter& rewriter) const override { // NOLINT
if (op->dyn_cast<pir::GetParameterOp>()) {
if (op->isa<pir::GetParameterOp>()) {
// Delete parameter from program.
pir::GetParameterOp get_parameter_op =
op->dyn_cast<pir::GetParameterOp>();
Expand Down
8 changes: 2 additions & 6 deletions paddle/fluid/pir/transforms/transform_general_functions.cc
Expand Up @@ -22,8 +22,7 @@

namespace pir {

std::pair<std::string, pir::Parameter*> GetParameterFromValue(
pir::Value value) {
std::string GetParameterNameFromValue(pir::Value value) {
pir::GetParameterOp op =
value.dyn_cast<OpResult>().owner()->dyn_cast<pir::GetParameterOp>();
PADDLE_ENFORCE_NOT_NULL(
Expand All @@ -37,10 +36,7 @@ std::pair<std::string, pir::Parameter*> GetParameterFromValue(
.at(op.attributes_name[0])
.dyn_cast<pir::StrAttribute>()
.AsString();
pir::Parameter* param = program->GetParameter(name);
PADDLE_ENFORCE_NOT_NULL(
param, phi::errors::InvalidArgument("Parameter should not be null."));
return {name, param};
return name;
}

const phi::DDim& GetShapeFromValue(pir::Value value) {
Expand Down