Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Inference] Support constant_folding_pass on PIR #58753

Merged
merged 9 commits into from
Nov 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions paddle/fluid/inference/api/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,14 @@ if(WIN32)
target_link_libraries(paddle_inference_api phi)
endif()

set(inference_deps ${analysis_deps} paddle_inference_api analysis
analysis_config naive_executor ${GLOB_PASS_LIB})
set(inference_deps
${analysis_deps}
paddle_inference_api
analysis
analysis_config
naive_executor
${GLOB_PASS_LIB}
transform)

if(WITH_GPU AND TENSORRT_FOUND)
set(inference_deps ${inference_deps} tensorrt_engine tensorrt_converter)
Expand Down
4 changes: 3 additions & 1 deletion paddle/fluid/inference/api/analysis_predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@
#endif

#include "paddle/fluid/ir_adaptor/translator/translate.h"
#include "paddle/fluid/pir/transforms/constant_folding_pass.h"
#include "paddle/fluid/pir/transforms/dead_code_elimination_pass.h"
#include "paddle/fluid/pir/transforms/inplace_pass.h"
#include "paddle/fluid/pir/transforms/pd_op_to_kernel_pass.h"
Expand Down Expand Up @@ -731,10 +732,11 @@ bool AnalysisPredictor::PrepareExecutor() {
paddle::TranslateLegacyProgramToProgram(*inference_program_));

::pir::PassManager pm(::pir::IrContext::Instance(), 2);
pm.AddPass(::pir::CreateConstantFoldingPass(place_, sub_scope_));
pm.AddPass(::pir::CreateReplaceFetchWithShadowOutputPass());
pm.AddPass(::pir::CreateDeadCodeEliminationPass());

pm.EnableIRPrinting();
// pm.EnableIRPrinting();
pm.Run(pir_program_.get());

pir_program_ = std::move(
Expand Down
215 changes: 111 additions & 104 deletions paddle/fluid/pir/transforms/constant_folding_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,21 @@
#include <memory>
#include <string>
#include <unordered_map>

// NOTE(zhangbo9674): File pd_op.h is generated by op_gen.py, see details in
// paddle/fluid/pir/dialect/CMakeLists.txt.
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include <vector>

#include "paddle/fluid/framework/new_executor/interpretercore.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/transforms/pd_op_to_kernel_pass.h"
#include "paddle/fluid/pir/transforms/transform_general_functions.h"

#include "paddle/phi/common/place.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"

#include "paddle/pir/core/builtin_attribute.h"
#include "paddle/pir/core/builtin_op.h"
#include "paddle/pir/core/ir_context.h"
#include "paddle/pir/core/op_result.h"
Expand All @@ -46,115 +47,101 @@ namespace {

class ConstantFoldingPattern : public pir::RewritePattern {
public:
ConstantFoldingPattern(pir::IrContext* context,
paddle::framework::Scope* scope,
pir::PatternBenefit benefit = 1,
const std::vector<std::string>& generated_names = {})
: RewritePattern(MatchAnyOpTypeTag(), benefit, context, generated_names),
scope_(scope) {}
ConstantFoldingPattern(
pir::IrContext* context,
size_t* suffix,
const phi::Place& place,
paddle::framework::Scope* scope,
paddle::framework::interpreter::ExecutionConfig* exe_config,
std::vector<std::string>* deleted_vars)
: RewritePattern(MatchAnyOpTypeTag(),
1 /*benefit*/,
context,
{} /*generated_names*/),
counter_(suffix),
place_(place),
scope_(scope),
exe_config_(exe_config),
deleted_vars_(deleted_vars) {
exe_config_->create_local_scope = false;
}

bool Match(pir::Operation* op) const override {
// TODO(liuyuanle): Use trait to improve robustness.
if (op->isa<pir::GetParameterOp>() || op->isa<pir::SetParameterOp>() ||
op->isa<paddle::dialect::FetchOp>() ||
op->isa<paddle::dialect::ShadowOutputOp>())
op->isa<pir::ShadowOutputOp>() || op->isa<paddle::dialect::FetchOp>() ||
op->isa<paddle::dialect::FeedOp>())
return false;

// Inputs must come from get parameter op.
for (uint32_t i = 0; i < op->num_operands(); ++i)
if (!pir::GetDefiningOpForInput(op, i)->isa<pir::GetParameterOp>())
return false;
if (!ValidOp(op)) {
return false;
}

return true;
}

void Rewrite(pir::Operation* op,
pir::PatternRewriter& rewriter) const override { // NOLINT
pir::Program* program = op->GetParentProgram();
auto temp_program = BuildProgramFromOperation(op);

std::vector<std::string> fetch_var_names;
auto block = temp_program->block();
for (auto it = block->begin(); it != block->end(); ++it) {
if ((*it)->isa<paddle::dialect::FetchOp>()) {
size_t index = (*it)
->attributes()
.at("col")
.dyn_cast<pir::Int32Attribute>()
.data();

if (fetch_var_names.size() < index + 1) {
fetch_var_names.resize(index + 1);
}

fetch_var_names[index] = (*it)
->attributes()
.at("name")
.dyn_cast<pir::StrAttribute>()
.AsString() +
"@fetch";
}
}
    VLOG(4) << "constant_folding_pass applies on [" << op->name() << "] op";
pir::Program new_program(ir_context());
auto output_var_name = BuildProgramFromOperation(op, &new_program);

// Execute program
exe_config_.create_local_scope = false;
// execute program
exe_config_->skip_gc_vars.insert(output_var_name);
auto kernel_program =
paddle::dialect::PdOpLowerToKernelPass(temp_program.get());
paddle::framework::InterpreterCore core(phi::CPUPlace{},
fetch_var_names,
kernel_program->block(),
scope_,
exe_config_);

paddle::framework::FetchList fetch_list = core.Run({});

// TODO(liuyuanle): Support multiple output.
auto out_tensor = PADDLE_GET_CONST(phi::DenseTensor, fetch_list[0]);
std::unique_ptr<pir::Parameter> parameter =
std::make_unique<pir::Parameter>(
reinterpret_cast<void*>(out_tensor.data()),
out_tensor.numel() * phi::SizeOf(out_tensor.dtype()),
op->result(0).type());

std::string param_name =
"@constant_folding_pass@_" + std::to_string(suffix_++);
exe_config_.skip_gc_vars.insert(param_name);

auto* param_var = scope_->Var(param_name);
auto* param_tensor = param_var->GetMutable<phi::DenseTensor>();
*param_tensor = out_tensor;
program->SetParameter(param_name, std::move(parameter));
// rewriter.SetInsertionPoint(op);
auto get_parameter_op =
rewriter.Build<pir::GetParameterOp>(param_name, op->result(0).type());
paddle::dialect::PdOpLowerToKernelPass(&new_program, place_);
paddle::framework::InterpreterCore core(
place_, {}, kernel_program->block(), scope_, *exe_config_);

core.Run({});

// TODO(liuyuanle): support multiple output
auto get_parameter_op = rewriter.Build<pir::GetParameterOp>(
output_var_name, op->result(0).type());
get_parameter_op->set_attribute(
kAttrIsPersisable, rewriter.array_attr({rewriter.bool_attr(true)}));

VLOG(4) << "constant_folding_pass applied on [" << op->name() << "] op";
rewriter.ReplaceAllUsesWith(op->result(0), get_parameter_op->result(0));
rewriter.EraseOp(op);
}

private:
std::unique_ptr<pir::Program> BuildProgramFromOperation(
pir::Operation* op) const {
auto program = std::make_unique<pir::Program>(ir_context());
pir::Builder builder = pir::Builder(ir_context(), program->block());
bool ValidOp(pir::Operation* op) const {
for (uint32_t i = 0; i < op->num_operands(); i++) {
// 1. inputs must come from get_parameter op
// 2. inputs must be a dense tensor type
if (!pir::GetDefiningOpForInput(op, i)->isa<pir::GetParameterOp>() ||
!op->operand_source(i)
.type()
.isa<paddle::dialect::DenseTensorType>()) {
return false;
}
// 3. outputs must be a dense tensor type
for (uint32_t i = 0; i < op->num_results(); i++) {
if (!op->result(i).type().isa<paddle::dialect::DenseTensorType>()) {
return false;
}
}
}
return true;
}

std::string BuildProgramFromOperation(pir::Operation* op,
pir::Program* new_program) const {
pir::Builder builder = pir::Builder(ir_context(), new_program->block());

// prepare op inputs
std::vector<pir::Value> op_inputs;
for (uint32_t i = 0; i < op->num_operands(); i++) {
PADDLE_ENFORCE_EQ(
op->operand_source(i).type().isa<paddle::dialect::DenseTensorType>(),
true,
phi::errors::InvalidArgument(
"Op's input must be a dense tensor type."));

auto [param_name, param] =
pir::GetParameterFromValue(op->operand_source(i));
program->SetParameter(param_name,
std::make_unique<pir::Parameter>(*param));

const auto& param_name =
pir::GetParameterNameFromValue(op->operand_source(i));
auto* param_var = scope_->FindVar(param_name);
PADDLE_ENFORCE_NOT_NULL(
param_var,
phi::errors::InvalidArgument("Parameter var not in scope."));
if (op->operand_source(i).use_count() == 1) {
deleted_vars_->push_back(param_name);
}

auto get_parameter_op = builder.Build<pir::GetParameterOp>(
param_name, op->operand_source(i).type());
Expand All @@ -170,60 +157,80 @@ class ConstantFoldingPattern : public pir::RewritePattern {
auto* temp_op =
builder.Build(op_inputs, op->attributes(), output_types, op->info());

// TODO(liuyuanle): Support multiple output.
// TODO(liuyuanle): support multiple output
// for (uint32_t i = 0; i < op->num_results(); i++) {
PADDLE_ENFORCE_EQ(
temp_op->result(0).type().isa<paddle::dialect::DenseTensorType>(),
true,
phi::errors::InvalidArgument(
"Op's output must be a dense tensor type."));

builder.Build<paddle::dialect::FetchOp>(
temp_op->result(0), "fetch_" + std::to_string(suffix_++), 0);

std::stringstream ss;
ss << std::chrono::high_resolution_clock::now().time_since_epoch().count();
std::string output_var_name =
"constant_folding@_" + ss.str() + std::to_string((*counter_)++);

builder.Build<pir::ShadowOutputOp>(temp_op->result(0), output_var_name);
// }

return program;
return output_var_name;
}

private:
size_t* counter_{nullptr};
phi::Place place_;
paddle::framework::Scope* scope_{nullptr};
inline static size_t suffix_{0};
inline static paddle::framework::interpreter::ExecutionConfig exe_config_{};
paddle::framework::interpreter::ExecutionConfig* exe_config_{nullptr};
std::vector<std::string>* deleted_vars_{nullptr};
};

class ConstantFoldingPass : public pir::Pass {
public:
ConstantFoldingPass() : pir::Pass("constant_folding_pass", 1) {}
ConstantFoldingPass(const phi::Place& place, paddle::framework::Scope* scope)
: pir::Pass("constant_folding_pass", 1), place_(place), scope_(scope) {
PADDLE_ENFORCE_NOT_NULL(
scope_, phi::errors::InvalidArgument("scope can not be nullptr"));
}

private:
bool Initialize(pir::IrContext* context) override {
pir::RewritePatternSet ps(context);
ps.Add<ConstantFoldingPattern>(context, &scope_);
ps.Add<ConstantFoldingPattern>(
context, &counter_, place_, scope_, &exe_config_, &deleted_vars_);
patterns_ = pir::FrozenRewritePatternSet(std::move(ps));
return true;
}

void Run(pir::Operation* op) override {
size_t op_nums = op->GetParentProgram()->block()->size();
pir::GreedyRewriteConfig cfg;
cfg.use_top_down_traversal = true;
cfg.max_iterations = 10;
pir::ApplyPatternsGreedily(op->region(0), patterns_, cfg);

// delete old parameter var
scope_->EraseVars(deleted_vars_);
LOG(INFO) << " ------ constant_folding_pass done: [" << counter_ << "/"
<< op_nums << "]";
}

bool CanApplyOn(pir::Operation* op) const override {
// TODO(liuyuanle): remove op->isa<::pir::ModuleOp>()
return op->isa<::pir::ModuleOp>() && op->num_regions() > 0;
}

private:
size_t counter_{0};
phi::Place place_;
paddle::framework::Scope* scope_{nullptr};
paddle::framework::interpreter::ExecutionConfig exe_config_{};
std::vector<std::string> deleted_vars_;

pir::FrozenRewritePatternSet patterns_;
paddle::framework::Scope scope_;
};

} // namespace

namespace pir {

std::unique_ptr<Pass> CreateConstantFoldingPass() {
return std::make_unique<ConstantFoldingPass>();
std::unique_ptr<Pass> CreateConstantFoldingPass(
const phi::Place& place, paddle::framework::Scope* scope) {
return std::make_unique<ConstantFoldingPass>(place, scope);
}

} // namespace pir
10 changes: 9 additions & 1 deletion paddle/fluid/pir/transforms/constant_folding_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,20 @@
#pragma once

#include <memory>
#include "paddle/phi/common/place.h"
#include "paddle/pir/core/dll_decl.h"

namespace paddle {
namespace framework {
class Scope;
}
} // namespace paddle

namespace pir {

class Pass;

IR_API std::unique_ptr<Pass> CreateConstantFoldingPass();
IR_API std::unique_ptr<Pass> CreateConstantFoldingPass(
const phi::Place& place, paddle::framework::Scope* scope);

} // namespace pir
9 changes: 5 additions & 4 deletions paddle/fluid/pir/transforms/dead_code_elimination_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/pir/core/builtin_op.h"
#include "paddle/pir/core/program.h"
#include "paddle/pir/dialect/control_flow/ir/cf_op.h"
#include "paddle/pir/pass/pass.h"
#include "paddle/pir/pass/pass_registry.h"
#include "paddle/pir/pattern_rewrite/frozen_rewrite_pattern_set.h"
Expand All @@ -35,16 +36,16 @@ class DeadCodeEliminationPattern : public pir::RewritePattern {
}

bool Match(pir::Operation* op) const override {
if (op->isa<paddle::dialect::FetchOp>() ||
op->isa<paddle::dialect::ShadowOutputOp>())
if (op->isa<paddle::dialect::FetchOp>() || op->isa<pir::ShadowOutputOp>() ||
op->isa<pir::YieldOp>()) {
return false;

}
return op->use_empty();
}

void Rewrite(pir::Operation* op,
pir::PatternRewriter& rewriter) const override { // NOLINT
if (op->dyn_cast<pir::GetParameterOp>()) {
if (op->isa<pir::GetParameterOp>()) {
// Delete parameter from program.
pir::GetParameterOp get_parameter_op =
op->dyn_cast<pir::GetParameterOp>();
Expand Down
Loading