Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PIR] pir onednn add conv_bias_pass #61137

Merged
23 changes: 23 additions & 0 deletions paddle/fluid/inference/api/analysis_predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@

#ifdef PADDLE_WITH_DNNL
#include "paddle/fluid/inference/api/mkldnn_quantizer.h"
#include "paddle/fluid/pir/transforms/onednn/conv_bias_fuse_pass.h"
#endif

#ifdef PADDLE_WITH_ONNXRUNTIME
Expand Down Expand Up @@ -835,6 +836,28 @@ bool AnalysisPredictor::PrepareExecutor() {
gpu_pm.EnableIRPrinting();
}
gpu_pm.Run(pir_program_.get());
#ifdef PADDLE_WITH_DNNL
} else if (config_.mkldnn_enabled()) {
::pir::PassManager mkldnn_pm(::pir::IrContext::Instance(), 2);

mkldnn_pm.AddPass(::pir::CreateConv2dBiasFusePass());

auto constant_folding_pass = ::pir::CreateConstantFoldingPass();
constant_folding_pass->SetNotOwned(pir::kPlaceAttr, &place_);
constant_folding_pass->SetNotOwned(pir::kParamScopeAttr, sub_scope_);

mkldnn_pm.AddPass(std::move(constant_folding_pass));
mkldnn_pm.AddPass(::pir::CreateDeadCodeEliminationPass());
mkldnn_pm.AddPass(::pir::CreateReplaceFetchWithShadowOutputPass());
//----------------------------------------------------------------------------------------------//
if (!config_.glog_info_disabled()) {
mkldnn_pm.EnablePrintStatistics();
}
if (config_.ir_debug_) {
mkldnn_pm.EnableIRPrinting();
}
mkldnn_pm.Run(pir_program_.get());
#endif
} else {
::pir::PassManager cpu_pm(::pir::IrContext::Instance(), 2);

Expand Down
71 changes: 63 additions & 8 deletions paddle/fluid/pir/dialect/op_generator/op_creator_drr_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,26 +67,73 @@
}});
"""

Dialect2NameSpaceMap = {"pd_op": "paddle::dialect", "cinn_op": "cinn::dialect"}
Dialect2NameSpaceMap = {
"pd_op": "paddle::dialect",
"cinn_op": "cinn::dialect",
"onednn_op": "paddle::onednn::dialect",
}
Dialect2OpHeaderMap = {
"pd_op": "#include \"paddle/fluid/pir/dialect/operator/ir/pd_op.h\"",
"cinn_op": "#include \"paddle/cinn/hlir/dialect/operator/ir/cinn_op.h\"",
"onednn_op": "#include \"paddle/fluid/pir/dialect/operator/ir/onednn_op.h\"",
}


class OpCreatorCodeGen:
def __init__(self, op_yaml_files, op_compat_yaml_file, dialect_name):
self.op_info_items = self.parse_yaml(op_yaml_files, op_compat_yaml_file)
def __init__(
self,
op_yaml_files,
op_compat_yaml_file,
dialect_name,
onednn_yaml_file=None,
ops_onednn_extra_yaml_file=None,
):
self.op_info_items = self.parse_yaml(
op_yaml_files,
op_compat_yaml_file,
onednn_yaml_file,
ops_onednn_extra_yaml_file,
)
self.dialect_name = dialect_name

def parse_yaml(self, op_yaml_files, op_compat_yaml_file):
def parse_yaml(
self,
op_yaml_files,
op_compat_yaml_file,
onednn_yaml_file=None,
ops_onednn_extra_yaml_file=None,
):
op_compat_parser = OpCompatParser(op_compat_yaml_file)

op_yaml_items = []
for yaml_file in op_yaml_files:
with open(yaml_file, "r") as f:

if dialect_name == "onednn_op":
with open(ops_onednn_extra_yaml_file, "r") as f:
ops_onednn_extra = yaml.safe_load(f)
ops_onednn_extra_set = set()
for op in ops_onednn_extra:
ops_onednn_extra_set.add(op['op'])
with open(onednn_yaml_file, "r") as f:
ops = yaml.safe_load(f)
op_yaml_items = op_yaml_items + ops
onednn_ops = []
for op in ops:
onednn_ops.append(op)
op_yaml_items = op_yaml_items + onednn_ops

for yaml_file in op_yaml_files:
with open(yaml_file, "r") as f:
ops = yaml.safe_load(f)
onednn_ops = []
for op in ops:
if op['name'] in ops_onednn_extra_set:
onednn_ops.append(op)
op_yaml_items = op_yaml_items + onednn_ops

else:
for yaml_file in op_yaml_files:
with open(yaml_file, "r") as f:
ops = yaml.safe_load(f)
op_yaml_items = op_yaml_items + ops

op_info_items = []
for op in op_yaml_items:
Expand Down Expand Up @@ -167,6 +214,8 @@ def ParseArguments():
parser.add_argument('--op_compat_yaml_file', type=str)
parser.add_argument('--dialect_name', type=str)
parser.add_argument('--op_creator_file', type=str)
parser.add_argument('--onednn_yaml_file', type=str)
parser.add_argument('--ops_onednn_extra_yaml_file', type=str)
return parser.parse_args()


Expand All @@ -176,8 +225,14 @@ def ParseArguments():
op_compat_yaml_file = args.op_compat_yaml_file
op_creator_file = args.op_creator_file
dialect_name = args.dialect_name
onednn_yaml_file = args.onednn_yaml_file
ops_onednn_extra_yaml_file = args.ops_onednn_extra_yaml_file

code_gen = OpCreatorCodeGen(
op_yaml_files, op_compat_yaml_file, dialect_name
op_yaml_files,
op_compat_yaml_file,
dialect_name,
onednn_yaml_file,
ops_onednn_extra_yaml_file,
)
code_gen.gen_cpp_file_code(op_creator_file)
41 changes: 40 additions & 1 deletion paddle/fluid/pir/drr/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,46 @@ if(WITH_CINN AND NOT CINN_ONLY)

endif()

if(WITH_MKLDNN)
set(onednn_dialect_name onednn_op)
set(pir_op_onednn_yaml
${PADDLE_BINARY_DIR}/paddle/fluid/pir/dialect/operator/ir/generated/onednn.parsed.yaml
)
set(pd_ops_onednn_extra_yaml_file
${PADDLE_SOURCE_DIR}/paddle/fluid/pir/dialect/operator/ir/ops_onednn_extra.yaml
)
set(onednn_op_creator_file
${PADDLE_BINARY_DIR}/paddle/fluid/pir/drr/onednn_op_factory_generated.cc)
set(onednn_op_creator_file_tmp ${onednn_op_creator_file}.tmp)

add_custom_command(
OUTPUT ${onednn_op_creator_file}
COMMAND
${PYTHON_EXECUTABLE} ${op_creator_gen_file} --op_yaml_files
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

onednn op的创建也需要op_yaml_files吗?这里的逻辑和53-69行的逻辑是不是在创建pd op的部分重复了?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

onednn op的创建也需要op_yaml_files。因为ops_onednn_extra.yaml中只有补充信息,主要的OP信息还是在原来的ops.yaml中,需要复用。

${op_yaml_files} --op_compat_yaml_file ${op_compat_yaml_file}
--dialect_name ${onednn_dialect_name} --op_creator_file
${onednn_op_creator_file_tmp} --onednn_yaml_file ${pir_op_onednn_yaml}
--ops_onednn_extra_yaml_file ${pd_ops_onednn_extra_yaml_file}
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${onednn_op_creator_file_tmp}
${onednn_op_creator_file}
COMMENT "copy_if_different ${onednn_op_creator_file}"
DEPENDS ${op_creator_gen_file}
${op_forward_yaml_file1}
${op_backward_yaml_file1}
${op_compat_yaml_file}
${op_yaml_file3}
${op_yaml_file4}
op_dialect
VERBATIM)

endif()

set(DRR_SRCS ${DRR_SRCS} ${op_creator_file} ${CINN_SOURCE_FILE})

if(WITH_MKLDNN)
set(DRR_SRCS ${DRR_SRCS} ${onednn_op_creator_file})
endif()
cc_library(
drr
SRCS ${DRR_SRCS} ${op_creator_file} ${CINN_SOURCE_FILE}
SRCS ${DRR_SRCS}
DEPS op_dialect_vjp ${CINN_DEPS} pir)
7 changes: 7 additions & 0 deletions paddle/fluid/pir/drr/include/drr_pattern_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,13 @@ class ResultPattern {
});
}

Attribute VectorFloatAttr(const std::vector<float>& value) const {
return ComputeAttr(
[=](const MatchContext& match_ctx) -> std::vector<float> {
return value;
});
}

Attribute ComputeAttr(const AttrComputeFunc& attr_compute_func) const {
return ComputeAttribute(attr_compute_func);
}
Expand Down
6 changes: 6 additions & 0 deletions paddle/fluid/pir/drr/ir_operation_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ class OperationFactory {
RegisterPdOpGeneratedOpCreator();
#ifdef PADDLE_WITH_CINN
RegisterCinnOpGeneratedOpCreator();
#endif
#ifdef PADDLE_WITH_DNNL
RegisterOnednnOpGeneratedOpCreator();
#endif
RegisterManualOpCreator();
}
Expand All @@ -69,6 +72,9 @@ class OperationFactory {
void RegisterPdOpGeneratedOpCreator();
#ifdef PADDLE_WITH_CINN
void RegisterCinnOpGeneratedOpCreator();
#endif
#ifdef PADDLE_WITH_DNNL
void RegisterOnednnOpGeneratedOpCreator();
#endif
std::unordered_map<std::string, operation_create_fn> op_creator_map;
};
Expand Down
5 changes: 5 additions & 0 deletions paddle/fluid/pir/transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@ if(NOT WITH_CINN)
${CMAKE_CURRENT_SOURCE_DIR}/sub_graph_detector.cc)
endif()

if(NOT WITH_MKLDNN)
file(GLOB_RECURSE onednn_srcs "onednn/*.cc")
list(REMOVE_ITEM transforms_srcs ${onednn_srcs})
endif()

set(transforms_deps drr op_dialect op_dialect_vjp standalone_executor pir
device_event_base)

Expand Down
Loading