From 4742d634dec35c62af9c399e632115d86b273f25 Mon Sep 17 00:00:00 2001 From: zhangbo9674 <82555433+zhangbo9674@users.noreply.github.com> Date: Fri, 13 Oct 2023 21:55:22 +0800 Subject: [PATCH] [PIR] Reconstruct the Verify system (#58052) * refine verify of if op * fix * fix * fix * refine * fix * fix * fix * fix --- .../hlir/dialect/operator/ir/manual_op.cc | 2 +- .../cinn/hlir/dialect/operator/ir/manual_op.h | 2 +- .../hlir/dialect/runtime/ir/jit_kernel_op.cc | 2 +- .../hlir/dialect/runtime/ir/jit_kernel_op.h | 2 +- .../translator/program_translator.cc | 2 + .../fluid/pir/dialect/kernel/ir/kernel_op.cc | 4 +- .../fluid/pir/dialect/kernel/ir/kernel_op.h | 4 +- .../fluid/pir/dialect/op_generator/op_gen.py | 2 +- .../pir/dialect/op_generator/op_verify_gen.py | 4 +- .../dialect/operator/ir/control_flow_op.cc | 70 ++++++++++++++++++- .../pir/dialect/operator/ir/control_flow_op.h | 6 +- .../pir/dialect/operator/ir/manual_op.cc | 12 ++-- .../fluid/pir/dialect/operator/ir/manual_op.h | 12 ++-- paddle/pir/core/builtin_op.cc | 16 ++--- paddle/pir/core/builtin_op.h | 16 ++--- paddle/pir/core/dialect.h | 3 +- paddle/pir/core/ir_context.cc | 6 +- paddle/pir/core/ir_context.h | 3 +- paddle/pir/core/op_base.h | 17 ++++- paddle/pir/core/op_info.cc | 13 +++- paddle/pir/core/op_info.h | 4 ++ paddle/pir/core/op_info_impl.cc | 6 +- paddle/pir/core/op_info_impl.h | 17 +++-- paddle/pir/core/operation.cc | 2 +- paddle/pir/dialect/control_flow/ir/cf_ops.h | 2 +- paddle/pir/dialect/shape/ir/shape_op.h | 12 ++-- test/cpp/pir/core/ir_infershape_test.cc | 2 +- test/cpp/pir/core/ir_program_test.cc | 4 +- test/cpp/pir/pass/pass_manager_test.cc | 4 +- .../pattern_rewrite/pattern_rewrite_test.cc | 8 +-- test/cpp/pir/tools/test_op.cc | 4 +- test/cpp/pir/tools/test_op.h | 36 +++++----- 32 files changed, 205 insertions(+), 94 deletions(-) diff --git a/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc b/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc index db81b53a16f96..3a4ebb63679f3 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc +++ b/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc @@ -44,7 +44,7 @@ std::vector GroupOp::ops() { inner_block->end()); } -void GroupOp::Verify() {} +void GroupOp::VerifySig() {} void GroupOp::Print(pir::IrPrinter &printer) { auto &os = printer.os; diff --git a/paddle/cinn/hlir/dialect/operator/ir/manual_op.h b/paddle/cinn/hlir/dialect/operator/ir/manual_op.h index 9d469d9f776c4..39d433790be78 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/manual_op.h +++ b/paddle/cinn/hlir/dialect/operator/ir/manual_op.h @@ -36,7 +36,7 @@ class GroupOp : public pir::Op { pir::Block *block(); std::vector ops(); - void Verify(); + void VerifySig(); void Print(pir::IrPrinter &printer); // NOLINT }; diff --git a/paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.cc b/paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.cc index ed3d4a4045c59..c98eb564b9735 100644 --- a/paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.cc +++ b/paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.cc @@ -22,7 +22,7 @@ namespace dialect { const char* JitKernelOp::attributes_name[attributes_num] = {kAttrName}; -void JitKernelOp::Verify() { +void JitKernelOp::VerifySig() { VLOG(4) << "Verifying inputs, outputs and attributes for: JitKernelOp."; auto& attributes = this->attributes(); diff --git a/paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.h b/paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.h index f410e4d46c021..62adcf2b1c7f1 100644 --- a/paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.h +++ b/paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.h @@ -51,7 +51,7 @@ class JitKernelOp : public ::pir::Op { hlir::framework::Instruction* instruction(); - void Verify(); + void VerifySig(); }; } // namespace dialect diff --git a/paddle/fluid/ir_adaptor/translator/program_translator.cc b/paddle/fluid/ir_adaptor/translator/program_translator.cc index 3cb2517642229..eb5b8c962e880 100644 --- a/paddle/fluid/ir_adaptor/translator/program_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/program_translator.cc @@ -367,6 +367,8 @@ pir::Operation* ProgramTranslator::TranslateCondIfOperation( true); } VLOG(4) << "[general op][conditional_block] IfOp false block translate end."; + + operation->Verify(); VLOG(4) << "[general op][conditional_block] IfOp translate end."; return operation; } diff --git a/paddle/fluid/pir/dialect/kernel/ir/kernel_op.cc b/paddle/fluid/pir/dialect/kernel/ir/kernel_op.cc index 62c1129f84620..8ad46bc8906ad 100644 --- a/paddle/fluid/pir/dialect/kernel/ir/kernel_op.cc +++ b/paddle/fluid/pir/dialect/kernel/ir/kernel_op.cc @@ -25,7 +25,7 @@ const char* PhiKernelOp::attributes_name[attributes_num] = { // NOLINT "kernel_name", "kernel_key"}; -void PhiKernelOp::Verify() { +void PhiKernelOp::VerifySig() { VLOG(4) << "Verifying inputs, outputs and attributes for: PhiKernelOp."; auto& attributes = this->attributes(); @@ -64,7 +64,7 @@ const char* LegacyKernelOp::attributes_name[attributes_num] = { // NOLINT "kernel_name", "kernel_key"}; -void LegacyKernelOp::Verify() { +void LegacyKernelOp::VerifySig() { VLOG(4) << "Verifying inputs, outputs and attributes for: LegacyKernelOp."; auto& attributes = this->attributes(); diff --git a/paddle/fluid/pir/dialect/kernel/ir/kernel_op.h b/paddle/fluid/pir/dialect/kernel/ir/kernel_op.h index 8a18959665e0c..a96aa5732d580 100644 --- a/paddle/fluid/pir/dialect/kernel/ir/kernel_op.h +++ b/paddle/fluid/pir/dialect/kernel/ir/kernel_op.h @@ -29,7 +29,7 @@ class PhiKernelOp : public pir::Op { std::string op_name(); std::string kernel_name(); phi::KernelKey kernel_key(); - void Verify(); + void VerifySig(); }; class LegacyKernelOp : public pir::Op { @@ -41,7 +41,7 @@ class LegacyKernelOp : public pir::Op { std::string op_name(); std::string kernel_name(); phi::KernelKey kernel_key(); - void Verify(); + void VerifySig(); }; } // namespace dialect diff --git a/paddle/fluid/pir/dialect/op_generator/op_gen.py b/paddle/fluid/pir/dialect/op_generator/op_gen.py index d9dd1cc879a23..64caafc544892 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_gen.py @@ -99,7 +99,7 @@ class {op_name} : public pir::Op<{op_name}{interfaces}{traits}> {{ {build_mutable_attr_is_input} {build_attr_num_over_1} {build_mutable_attr_is_input_attr_num_over_1} - void Verify(); + void VerifySig(); {get_inputs_and_outputs} {exclusive_interface} }}; diff --git a/paddle/fluid/pir/dialect/op_generator/op_verify_gen.py b/paddle/fluid/pir/dialect/op_generator/op_verify_gen.py index 1b8c82b27d90b..3a2515f278915 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_verify_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_verify_gen.py @@ -14,7 +14,7 @@ # verify OP_VERIFY_TEMPLATE = """ -void {op_name}::Verify() {{ +void {op_name}::VerifySig() {{ VLOG(4) << "Start Verifying inputs, outputs and attributes for: {op_name}."; VLOG(4) << "Verifying inputs:"; {{ @@ -36,7 +36,7 @@ """ GRAD_OP_VERIFY_TEMPLATE = """ -void {op_name}::Verify() {{}} +void {op_name}::VerifySig() {{}} """ INPUT_TYPE_CHECK_TEMPLATE = """ diff --git a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc index c47dd600bace4..3afbec0661662 100644 --- a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc +++ b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc @@ -19,6 +19,7 @@ paddle::dialect::IfOp, paddle::dialect::WhileOp #include "paddle/phi/core/enforce.h" #include "paddle/pir/core/builder.h" +#include "paddle/pir/core/builtin_type.h" #include "paddle/pir/core/ir_printer.h" #include "paddle/pir/core/operation_utils.h" #include "paddle/pir/dialect/control_flow/ir/cf_ops.h" @@ -109,7 +110,74 @@ void IfOp::Print(pir::IrPrinter &printer) { } os << "\n }"; } -void IfOp::Verify() {} +void IfOp::VerifySig() { + VLOG(4) << "Start Verifying inputs, outputs and attributes for: IfOp."; + auto input_size = num_operands(); + PADDLE_ENFORCE_EQ( + input_size, + 1u, + phi::errors::PreconditionNotMet( + "The size %d of inputs must be equal to 1.", input_size)); + + if ((*this)->operand_source(0).type().isa()) { + PADDLE_ENFORCE( + (*this) + ->operand_source(0) + .type() + .dyn_cast() + .dtype() + .isa(), + phi::errors::PreconditionNotMet( + "Type validation failed for the 1th input, it should be a " + "bool DenseTensorType.")); + } + + PADDLE_ENFORCE_EQ((*this)->num_regions(), + 2u, + phi::errors::PreconditionNotMet( + "The size %d of regions must be equal to 2.", + (*this)->num_regions())); +} + +void IfOp::VerifyRegion() { + VLOG(4) << "Start Verifying sub regions for: IfOp."; + PADDLE_ENFORCE_EQ( + (*this)->region(0).size(), + 1u, + phi::errors::PreconditionNotMet("The size %d of true_region must be 1.", + (*this)->region(0).size())); + + if ((*this)->num_results() != 0) { + PADDLE_ENFORCE_EQ( + (*this)->region(0).size(), + (*this)->region(1).size(), + phi::errors::PreconditionNotMet("The size %d of true_region must be " + "equal to the size %d of false_region.", + (*this)->region(0).size(), + (*this)->region(1).size())); + + auto *true_last_op = (*this)->region(0).front()->back(); + auto *false_last_op = (*this)->region(1).front()->back(); + PADDLE_ENFORCE_EQ(true_last_op->isa(), + true, + phi::errors::PreconditionNotMet( + "The last of true block must be YieldOp")); + PADDLE_ENFORCE_EQ(true_last_op->num_operands(), + (*this)->num_results(), + phi::errors::PreconditionNotMet( + "The size of last of true block op's input must be " + "equal to IfOp's outputs num.")); + PADDLE_ENFORCE_EQ(false_last_op->isa(), + true, + phi::errors::PreconditionNotMet( + "The last of false block must be YieldOp")); + PADDLE_ENFORCE_EQ(false_last_op->num_operands(), + (*this)->num_results(), + phi::errors::PreconditionNotMet( + "The size of last of false block op's input must be " + "equal to IfOp's outputs num.")); + } +} void WhileOp::Build(pir::Builder &builder, // NOLINT pir::OperationArgument &argument, // NOLINT diff --git a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.h b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.h index 48571d7e501ef..b8a92415394dd 100644 --- a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.h +++ b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.h @@ -41,7 +41,8 @@ class IfOp : public pir::Op { pir::Block *true_block(); pir::Block *false_block(); void Print(pir::IrPrinter &printer); // NOLINT - void Verify(); + void VerifySig(); + void VerifyRegion(); }; class WhileOp : public pir::Op { @@ -57,7 +58,8 @@ class WhileOp : public pir::Op { pir::Block *cond_block(); pir::Block *body_block(); void Print(pir::IrPrinter &printer); // NOLINT - void Verify() {} + void VerifySig() {} + void VerifyRegion() {} }; } // namespace dialect diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc index eb5f1f5a53670..0f636e01e19a3 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc @@ -50,7 +50,7 @@ OpInfoTuple AddNOp::GetOpInfo() { return std::make_tuple(inputs, attributes, outputs, run_time_info, "add_n"); } -void AddNOp::Verify() { +void AddNOp::VerifySig() { VLOG(4) << "Start Verifying inputs, outputs and attributes for: AddNOp."; VLOG(4) << "Verifying inputs:"; { @@ -222,7 +222,7 @@ void AddN_Op::Build(pir::Builder &builder, argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); } -void AddN_Op::Verify() { +void AddN_Op::VerifySig() { VLOG(4) << "Start Verifying inputs, outputs and attributes for: AddN_Op."; VLOG(4) << "Verifying inputs:"; { @@ -345,7 +345,7 @@ void AddNWithKernelOp::Build(pir::Builder &builder, argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); } -void AddNWithKernelOp::Verify() { +void AddNWithKernelOp::VerifySig() { VLOG(4) << "Start Verifying inputs, outputs and attributes for: " "AddNWithKernelOp."; VLOG(4) << "Verifying inputs:"; @@ -561,7 +561,7 @@ void FusedGemmEpilogueOp::Build(pir::Builder &builder, argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); } -void FusedGemmEpilogueOp::Verify() { +void FusedGemmEpilogueOp::VerifySig() { VLOG(4) << "Start Verifying inputs, outputs and attributes for: " "FusedGemmEpilogueOp."; VLOG(4) << "Verifying inputs:"; @@ -833,7 +833,7 @@ void FusedGemmEpilogueGradOp::Build(pir::Builder &builder, argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); } -void FusedGemmEpilogueGradOp::Verify() {} +void FusedGemmEpilogueGradOp::VerifySig() {} void FusedGemmEpilogueGradOp::InferMeta(phi::InferMetaContext *infer_meta) { auto fn = PD_INFER_META(phi::FusedGemmEpilogueGradInferMeta); @@ -983,7 +983,7 @@ void SplitGradOp::Build(pir::Builder &builder, argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); } -void SplitGradOp::Verify() { +void SplitGradOp::VerifySig() { VLOG(4) << "Start Verifying inputs, outputs and attributes for: SplitGradOp."; VLOG(4) << "Verifying inputs:"; { diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.h b/paddle/fluid/pir/dialect/operator/ir/manual_op.h index c6fc7cb32b316..317ce64feea08 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_op.h +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.h @@ -45,7 +45,7 @@ class AddNOp : public pir::Op { pir::Value out_grad_, pir::Value axis_); - void Verify(); + void VerifySig(); pir::Value out_grad() { return operand_source(0); } pir::Value axis() { return operand_source(1); } pir::OpResult x_grad() { return result(0); } diff --git a/paddle/pir/core/builtin_op.cc b/paddle/pir/core/builtin_op.cc index ada969399550a..b092cb7bed266 100644 --- a/paddle/pir/core/builtin_op.cc +++ b/paddle/pir/core/builtin_op.cc @@ -82,7 +82,7 @@ void ModuleOp::Destroy() { } } -void ModuleOp::Verify() const { +void ModuleOp::VerifySig() const { VLOG(4) << "Verifying inputs, outputs and attributes for: ModuleOp."; // Verify inputs: IR_ENFORCE(num_operands() == 0u, "The size of inputs must be equal to 0."); @@ -118,7 +118,7 @@ void GetParameterOp::PassStopGradients(OperationArgument &argument) { pir::ArrayAttribute::get(pir::IrContext::Instance(), outs_stop_gradient)); } -void GetParameterOp::Verify() const { +void GetParameterOp::VerifySig() const { VLOG(4) << "Verifying inputs, outputs and attributes for: GetParameterOp."; // Verify inputs: IR_ENFORCE(num_operands() == 0u, "The size of inputs must be equal to 0."); @@ -144,7 +144,7 @@ void SetParameterOp::Build(Builder &builder, // NOLINT argument.AddAttribute(attributes_name[0], pir::StrAttribute::get(builder.ir_context(), name)); } -void SetParameterOp::Verify() const { +void SetParameterOp::VerifySig() const { VLOG(4) << "Verifying inputs, outputs and attributes for: SetParameterOp."; // Verify inputs: IR_ENFORCE(num_operands() == 1, "The size of outputs must be equal to 1."); @@ -170,7 +170,7 @@ void ShadowOutputOp::Build(Builder &builder, // NOLINT argument.AddAttribute(attributes_name[0], pir::StrAttribute::get(builder.ir_context(), name)); } -void ShadowOutputOp::Verify() const { +void ShadowOutputOp::VerifySig() const { VLOG(4) << "Verifying inputs, outputs and attributes for: ShadowOutputOp."; // Verify inputs: IR_ENFORCE(num_operands() == 1, "The size of outputs must be equal to 1."); @@ -198,7 +198,7 @@ void CombineOp::Build(Builder &builder, PassStopGradientsDefaultly(argument); } -void CombineOp::Verify() const { +void CombineOp::VerifySig() const { // outputs.size() == 1 IR_ENFORCE(num_results() == 1u, "The size of outputs must be equal to 1."); @@ -260,7 +260,7 @@ void SliceOp::PassStopGradients(OperationArgument &argument, int index) { pir::ArrayAttribute::get(pir::IrContext::Instance(), outs_stop_gradient)); } -void SliceOp::Verify() const { +void SliceOp::VerifySig() const { // inputs.size() == 1 auto input_size = num_operands(); IR_ENFORCE( @@ -364,7 +364,7 @@ void SplitOp::PassStopGradients(OperationArgument &argument) { pir::ArrayAttribute::get(pir::IrContext::Instance(), outs_stop_gradient)); } -void SplitOp::Verify() const { +void SplitOp::VerifySig() const { // inputs.size() == 1 IR_ENFORCE(num_operands() == 1u, "The size of inputs must be equal to 1."); @@ -393,7 +393,7 @@ void ConstantOp::Build(Builder &builder, argument.output_types.push_back(output_type); } -void ConstantOp::Verify() const { +void ConstantOp::VerifySig() const { IR_ENFORCE(num_operands() == 0, "The size of inputs must be equal to 0."); IR_ENFORCE(num_results() == 1, "The size of outputs must be equal to 1."); IR_ENFORCE(attributes().count("value") > 0, "must has value attribute"); diff --git a/paddle/pir/core/builtin_op.h b/paddle/pir/core/builtin_op.h index 93a9e0786dd9b..19ca96b052692 100644 --- a/paddle/pir/core/builtin_op.h +++ b/paddle/pir/core/builtin_op.h @@ -31,7 +31,7 @@ class IR_API ModuleOp : public pir::Op { static const char *name() { return "builtin.module"; } static constexpr uint32_t attributes_num = 1; static const char *attributes_name[attributes_num]; - void Verify() const; + void VerifySig() const; Program *program(); Block *block(); @@ -56,7 +56,7 @@ class IR_API GetParameterOp : public pir::Op { OperationArgument &argument, // NOLINT const std::string &name, Type type); - void Verify() const; + void VerifySig() const; private: static void PassStopGradients(OperationArgument &argument); // NOLINT @@ -76,7 +76,7 @@ class IR_API SetParameterOp : public pir::Op { OperationArgument &argument, // NOLINT Value parameter, const std::string &name); - void Verify() const; + void VerifySig() const; }; /// @@ -93,7 +93,7 @@ class IR_API ShadowOutputOp : public pir::Op { OperationArgument &argument, // NOLINT Value parameter, const std::string &name); - void Verify() const; + void VerifySig() const; }; /// @@ -113,7 +113,7 @@ class IR_API CombineOp : public pir::Op { OperationArgument &argument, // NOLINT const std::vector &inputs); - void Verify() const; + void VerifySig() const; std::vector inputs() { std::vector inputs; for (uint32_t idx = 0; idx < num_operands(); idx++) { @@ -142,7 +142,7 @@ class IR_API SliceOp : public pir::Op { Value input, int index); - void Verify() const; + void VerifySig() const; pir::Value input() { return operand_source(0); } private: @@ -167,7 +167,7 @@ class IR_API SplitOp : public pir::Op { OperationArgument &argument, // NOLINT Value input); - void Verify() const; + void VerifySig() const; pir::Value input() { return operand_source(0); } std::vector outputs() { std::vector res; @@ -203,7 +203,7 @@ class IR_API ConstantOp : public Op { Attribute value, Type output_type); - void Verify() const; + void VerifySig() const; Attribute value() const; }; diff --git a/paddle/pir/core/dialect.h b/paddle/pir/core/dialect.h index 07debaf196041..8c66f3c1d6a15 100644 --- a/paddle/pir/core/dialect.h +++ b/paddle/pir/core/dialect.h @@ -100,7 +100,8 @@ class IR_API Dialect { ConcreteOp::GetTraitSet(), ConcreteOp::attributes_num, ConcreteOp::attributes_name, - ConcreteOp::VerifyInvariants); + ConcreteOp::VerifySigInvariants, + ConcreteOp::VerifyRegionInvariants); } void RegisterOp(const std::string &name, OpInfoImpl *op_info); diff --git a/paddle/pir/core/ir_context.cc b/paddle/pir/core/ir_context.cc index d469b431bf923..1ebd9e4f0c642 100644 --- a/paddle/pir/core/ir_context.cc +++ b/paddle/pir/core/ir_context.cc @@ -292,7 +292,8 @@ void IrContext::RegisterOpInfo(Dialect *dialect, const std::vector &trait_set, size_t attributes_num, const char **attributes_name, - VerifyPtr verify) { + VerifyPtr verify_sig, + VerifyPtr verify_region) { if (impl().IsOpInfoRegistered(name)) { LOG(WARNING) << name << " op already registered."; } else { @@ -303,7 +304,8 @@ void IrContext::RegisterOpInfo(Dialect *dialect, trait_set, attributes_num, attributes_name, - verify); + verify_sig, + verify_region); impl().RegisterOpInfo(name, info); } } diff --git a/paddle/pir/core/ir_context.h b/paddle/pir/core/ir_context.h index d459f91524229..c20a0d7bba292 100644 --- a/paddle/pir/core/ir_context.h +++ b/paddle/pir/core/ir_context.h @@ -113,7 +113,8 @@ class IR_API IrContext { const std::vector &trait_set, size_t attributes_num, const char **attributes_name, - void (*verify)(Operation *)); + void (*verify_sig)(Operation *), + void (*verify_region)(Operation *)); /// /// \brief Get registered operaiton infomation. diff --git a/paddle/pir/core/op_base.h b/paddle/pir/core/op_base.h index 8e67a392c51cf..f0710ff5ec629 100644 --- a/paddle/pir/core/op_base.h +++ b/paddle/pir/core/op_base.h @@ -63,6 +63,10 @@ class IR_API OpBase { return operation()->attribute(name); } + void VerifySig() {} + + void VerifyRegion() {} + private: Operation *operation_; // Not owned }; @@ -162,14 +166,21 @@ class Op : public OpBase { class EmptyOp : public Op {}; return sizeof(ConcreteOp) == sizeof(EmptyOp); } - // Implementation of `VerifyInvariantsFn` OperationName hook. - static void VerifyInvariants(Operation *op) { + + // Implementation of `VerifySigInvariantsFn` OperationName hook. + static void VerifySigInvariants(Operation *op) { static_assert(HasNoDataMembers(), "Op class shouldn't define new data members"); - op->dyn_cast().Verify(); + op->dyn_cast().VerifySig(); (void)std::initializer_list{ 0, (VerifyTraitOrInterface::call(op), 0)...}; } + + static void VerifyRegionInvariants(Operation *op) { + static_assert(HasNoDataMembers(), + "Op class shouldn't define new data members"); + op->dyn_cast().VerifyRegion(); + } }; } // namespace pir diff --git a/paddle/pir/core/op_info.cc b/paddle/pir/core/op_info.cc index b018bec30448d..499bfda0e69e7 100644 --- a/paddle/pir/core/op_info.cc +++ b/paddle/pir/core/op_info.cc @@ -35,7 +35,18 @@ const char *OpInfo::name() const { return impl_ ? impl_->name() : nullptr; } TypeId OpInfo::id() const { return impl_ ? impl_->id() : TypeId(); } -void OpInfo::Verify(Operation *operation) const { impl_->verify()(operation); } +void OpInfo::Verify(Operation *operation) const { + VerifySig(operation); + VerifyRegion(operation); +} + +void OpInfo::VerifySig(Operation *operation) const { + impl_->VerifySig()(operation); +} + +void OpInfo::VerifyRegion(Operation *operation) const { + impl_->VerifyRegion()(operation); +} void *OpInfo::GetInterfaceImpl(TypeId interface_id) const { return impl_ ? impl_->GetInterfaceImpl(interface_id) : nullptr; diff --git a/paddle/pir/core/op_info.h b/paddle/pir/core/op_info.h index 23fc5bfe1b9eb..a7416c146a90e 100644 --- a/paddle/pir/core/op_info.h +++ b/paddle/pir/core/op_info.h @@ -54,6 +54,10 @@ class IR_API OpInfo { void Verify(Operation *) const; + void VerifySig(Operation *) const; + + void VerifyRegion(Operation *) const; + template bool HasTrait() const { return HasTrait(TypeId::get()); diff --git a/paddle/pir/core/op_info_impl.cc b/paddle/pir/core/op_info_impl.cc index 12245f12a652a..33320f1d52367 100644 --- a/paddle/pir/core/op_info_impl.cc +++ b/paddle/pir/core/op_info_impl.cc @@ -24,7 +24,8 @@ OpInfo OpInfoImpl::Create(Dialect *dialect, const std::vector &trait_set, size_t attributes_num, const char *attributes_name[], // NOLINT - VerifyPtr verify) { + VerifyPtr verify_sig, + VerifyPtr verify_region) { // (1) Malloc memory for interfaces, traits, opinfo_impl. size_t interfaces_num = interface_map.size(); size_t traits_num = trait_set.size(); @@ -59,7 +60,8 @@ OpInfo OpInfoImpl::Create(Dialect *dialect, traits_num, attributes_num, attributes_name, - verify)); + verify_sig, + verify_region)); return op_info; } void OpInfoImpl::Destroy(OpInfo info) { diff --git a/paddle/pir/core/op_info_impl.h b/paddle/pir/core/op_info_impl.h index cc63a52d40064..a08084682f1d0 100644 --- a/paddle/pir/core/op_info_impl.h +++ b/paddle/pir/core/op_info_impl.h @@ -42,14 +42,17 @@ class OpInfoImpl { const std::vector &trait_set, size_t attributes_num, const char *attributes_name[], - VerifyPtr verify); + VerifyPtr verify_sig, + VerifyPtr verify_region); static void Destroy(OpInfo info); TypeId id() const { return op_id_; } Dialect *dialect() const { return dialect_; } - VerifyPtr verify() const { return verify_; } + VerifyPtr VerifySig() const { return verify_sig_; } + + VerifyPtr VerifyRegion() const { return verify_region_; } IrContext *ir_context() const; @@ -76,7 +79,8 @@ class OpInfoImpl { uint32_t num_traits, uint32_t num_attributes, const char **p_attributes, - VerifyPtr verify) + VerifyPtr verify_sig, + VerifyPtr verify_region) : dialect_(dialect), op_id_(op_id), op_name_(op_name), @@ -84,7 +88,8 @@ class OpInfoImpl { num_traits_(num_traits), num_attributes_(num_attributes), p_attributes_(p_attributes), - verify_(verify) {} + verify_sig_(verify_sig), + verify_region_(verify_region) {} void Destroy(); /// The dialect of this Op belong to. @@ -108,7 +113,9 @@ class OpInfoImpl { /// Attributes array address. const char **p_attributes_{nullptr}; - VerifyPtr verify_{nullptr}; + VerifyPtr verify_sig_{nullptr}; + + VerifyPtr verify_region_{nullptr}; }; } // namespace pir diff --git a/paddle/pir/core/operation.cc b/paddle/pir/core/operation.cc index 6a13963c93587..0dedeafc9ae71 100644 --- a/paddle/pir/core/operation.cc +++ b/paddle/pir/core/operation.cc @@ -123,7 +123,7 @@ Operation *Operation::Create(const std::vector &inputs, // 0. Verify if (op_info) { - op_info.Verify(op); + op_info.VerifySig(op); } return op; } diff --git a/paddle/pir/dialect/control_flow/ir/cf_ops.h b/paddle/pir/dialect/control_flow/ir/cf_ops.h index fe3e965fede8f..7d669c0b648ea 100644 --- a/paddle/pir/dialect/control_flow/ir/cf_ops.h +++ b/paddle/pir/dialect/control_flow/ir/cf_ops.h @@ -28,7 +28,7 @@ class IR_API YieldOp : public Op { static void Build(Builder &builder, // NOLINT OperationArgument &argument, // NOLINT const std::vector &Value); - void Verify() {} + void VerifySig() {} }; } // namespace pir diff --git a/paddle/pir/dialect/shape/ir/shape_op.h b/paddle/pir/dialect/shape/ir/shape_op.h index c8ec2df012341..c838624d2566d 100644 --- a/paddle/pir/dialect/shape/ir/shape_op.h +++ b/paddle/pir/dialect/shape/ir/shape_op.h @@ -71,7 +71,7 @@ class IR_API SymbolicDim : public Op { return "kSymbolicDimAttr"; } - void Verify() {} + void VerifySig() {} }; class IR_API DimOp : public Op { @@ -89,7 +89,7 @@ class IR_API DimOp : public Op { const std::string getName(); void setName(std::string attrValue); OpResult out() { return result(0); } - void Verify() {} + void VerifySig() {} }; class IR_API TieProductEqualOp : public Op { @@ -111,7 +111,7 @@ class IR_API TieProductEqualOp : public Op { const std::vector &rhs); std::vector lhs(); std::vector rhs(); - void Verify() {} + void VerifySig() {} }; class IR_API TieShapeOp : public Op { @@ -132,7 +132,7 @@ class IR_API TieShapeOp : public Op { const std::vector &dims); Value value(); std::vector dims(); - void Verify() {} + void VerifySig() {} }; class IR_API FuncOp : public Op { @@ -147,7 +147,7 @@ class IR_API FuncOp : public Op { OperationArgument &argument); // NOLINT void Print(IrPrinter &printer); // NOLINT Block *block(); - void Verify() {} + void VerifySig() {} }; class IR_API TensorDimOp : public Op { @@ -169,7 +169,7 @@ class IR_API TensorDimOp : public Op { Value index(); Value source(); OpResult out() { return result(0); } - void Verify() {} + void VerifySig() {} }; } // namespace pir::dialect diff --git a/test/cpp/pir/core/ir_infershape_test.cc b/test/cpp/pir/core/ir_infershape_test.cc index 720d4b238d5eb..09d3a2fe9b6b1 100644 --- a/test/cpp/pir/core/ir_infershape_test.cc +++ b/test/cpp/pir/core/ir_infershape_test.cc @@ -45,7 +45,7 @@ class OperationTest static const char *name() { return "test.operation2"; } static constexpr uint32_t attributes_num = 2; static const char *attributes_name[attributes_num]; // NOLINT - static void Verify() {} + static void VerifySig() {} static void InferMeta(phi::InferMetaContext *infer_meta) { auto fn = PD_INFER_META(phi::CreateInferMeta); fn(infer_meta); diff --git a/test/cpp/pir/core/ir_program_test.cc b/test/cpp/pir/core/ir_program_test.cc index 85f608aa117a2..7ae348d004f53 100644 --- a/test/cpp/pir/core/ir_program_test.cc +++ b/test/cpp/pir/core/ir_program_test.cc @@ -41,14 +41,14 @@ class AddOp : public pir::Op { static const char *name() { return "test.add"; } static constexpr const char **attributes_name = nullptr; static constexpr uint32_t attributes_num = 0; - void Verify(); + void VerifySig(); static void Build(pir::Builder &builder, // NOLINT pir::OperationArgument &argument, // NOLINT pir::Value l_operand, pir::Value r_operand, pir::Type sum_type); }; -void AddOp::Verify() { +void AddOp::VerifySig() { if (num_operands() != 2) { throw("The size of inputs must be equal to 2."); } diff --git a/test/cpp/pir/pass/pass_manager_test.cc b/test/cpp/pir/pass/pass_manager_test.cc index e83764226ebd1..03e7d88d484bc 100644 --- a/test/cpp/pir/pass/pass_manager_test.cc +++ b/test/cpp/pir/pass/pass_manager_test.cc @@ -69,14 +69,14 @@ class AddOp : public pir::Op { static const char *name() { return "test.add"; } static constexpr const char **attributes_name = nullptr; static constexpr uint32_t attributes_num = 0; - void Verify(); + void VerifySig(); static void Build(pir::Builder &builder, // NOLINT pir::OperationArgument &argument, // NOLINT pir::OpResult l_operand, pir::OpResult r_operand, pir::Type sum_type); }; -void AddOp::Verify() { +void AddOp::VerifySig() { if (num_operands() != 2) { throw("The size of inputs must be equal to 2."); } diff --git a/test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc b/test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc index adfe431a6be2b..18644c08e21b7 100644 --- a/test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc +++ b/test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc @@ -79,11 +79,11 @@ class Operation1 : public pir::Op { static const char *name() { return "test.Operation1"; } static constexpr uint32_t attributes_num = 2; static const char *attributes_name[attributes_num]; // NOLINT - void Verify(); + void VerifySig(); static void InferShape() { VLOG(2) << "This is op2's InferShape interface."; } }; -void Operation1::Verify() { +void Operation1::VerifySig() { auto &attributes = this->attributes(); if (attributes.count("op2_attr1") == 0 || (!attributes.at("op2_attr1").isa())) { @@ -390,7 +390,7 @@ class Conv2dFusionOpTest : public pir::Opnum_successors() == 1u, "successors number must equal to 1."); IR_ENFORCE((*this)->successor(0), "successor[0] can't be nullptr"); @@ -45,7 +45,7 @@ void Operation1::Build(pir::Builder &builder, // NOLINT argument.AddOutput(builder.float32_type()); argument.AddAttributes(attributes); } -void Operation1::Verify() const { +void Operation1::VerifySig() const { auto &attributes = this->attributes(); if (attributes.count("op1_attr1") == 0 || !attributes.at("op1_attr1").isa()) { diff --git a/test/cpp/pir/tools/test_op.h b/test/cpp/pir/tools/test_op.h index 98f01db37614d..175a9268390e9 100644 --- a/test/cpp/pir/tools/test_op.h +++ b/test/cpp/pir/tools/test_op.h @@ -34,7 +34,7 @@ class RegionOp : public pir::Op { static constexpr const char **attributes_name = nullptr; static void Build(pir::Builder &builder, // NOLINT pir::OperationArgument &argument); // NOLINT - void Verify() const {} + void VerifySig() const {} }; /// @@ -50,7 +50,7 @@ class BranchOp : public pir::Op { pir::OperationArgument &argument, // NOLINT const std::vector &target_operands, pir::Block *target); - void Verify() const; + void VerifySig() const; }; // Define case op1. @@ -62,7 +62,7 @@ class Operation1 : public pir::Op { static const char *attributes_name[attributes_num]; static void Build(pir::Builder &builder, // NOLINT pir::OperationArgument &argument); // NOLINT - void Verify() const; + void VerifySig() const; }; // Define op2. @@ -75,7 +75,7 @@ class Operation2 static constexpr const char **attributes_name = nullptr; static void Build(pir::Builder &builder, // NOLINT pir::OperationArgument &argument) {} // NOLINT - void Verify() const {} + void VerifySig() const {} static void InferShape() { VLOG(2) << "This is op2's InferShape interface."; } }; @@ -98,7 +98,7 @@ class TraitExampleOp pir::Value l_operand, pir::Value r_operand, pir::Type out_type); - void Verify() const {} + void VerifySig() const {} }; // Define SameOperandsShapeTraitOp1. @@ -111,7 +111,7 @@ class SameOperandsShapeTraitOp1 static constexpr const char **attributes_name = nullptr; static void Build(pir::Builder &builder, // NOLINT pir::OperationArgument &argument) {} // NOLINT - void Verify() const {} + void VerifySig() const {} }; // Define SameOperandsShapeTraitOp2. @@ -127,7 +127,7 @@ class SameOperandsShapeTraitOp2 pir::Value l_operand, pir::Value r_operand, pir::Type out_type); - void Verify() const {} + void VerifySig() const {} }; // Define SameOperandsAndResultShapeTraitOp1. @@ -143,7 +143,7 @@ class SameOperandsAndResultShapeTraitOp1 static constexpr const char **attributes_name = nullptr; static void Build(pir::Builder &builder, // NOLINT pir::OperationArgument &argument) {} // NOLINT - void Verify() const {} + void VerifySig() const {} }; // Define SameOperandsAndResultShapeTraitOp2. @@ -161,7 +161,7 @@ class SameOperandsAndResultShapeTraitOp2 pir::OperationArgument &argument, // NOLINT pir::Value l_operand, pir::Value r_operand); - void Verify() const {} + void VerifySig() const {} }; // Define SameOperandsAndResultShapeTraitOp3. @@ -180,7 +180,7 @@ class SameOperandsAndResultShapeTraitOp3 pir::Value l_operand, pir::Value r_operand, pir::Type out_type); - void Verify() const {} + void VerifySig() const {} }; // Define SameOperandsElementTypeTraitOp1. @@ -194,7 +194,7 @@ class SameOperandsElementTypeTraitOp1 static constexpr const char **attributes_name = nullptr; static void Build(pir::Builder &builder, // NOLINT pir::OperationArgument &argument) {} // NOLINT - void Verify() const {} + void VerifySig() const {} }; // Define SameOperandsElementTypeTraitOp2. @@ -211,7 +211,7 @@ class SameOperandsElementTypeTraitOp2 pir::Value l_operand, pir::Value r_operand, pir::Type out_type); - void Verify() const {} + void VerifySig() const {} }; // Define SameOperandsAndResultElementTypeTraitOp1. @@ -227,7 +227,7 @@ class SameOperandsAndResultElementTypeTraitOp1 static constexpr const char **attributes_name = nullptr; static void Build(pir::Builder &builder, // NOLINT pir::OperationArgument &argument) {} // NOLINT - void Verify() const {} + void VerifySig() const {} }; // Define SameOperandsAndResultElementTypeTraitOp2. @@ -245,7 +245,7 @@ class SameOperandsAndResultElementTypeTraitOp2 pir::OperationArgument &argument, // NOLINT pir::Value l_operand, pir::Value r_operand); - void Verify() const {} + void VerifySig() const {} }; // Define SameOperandsAndResultElementTypeTraitOp3. @@ -265,7 +265,7 @@ class SameOperandsAndResultElementTypeTraitOp3 pir::Value r_operand, pir::Type out_type1, pir::Type out_type2); - void Verify() const {} + void VerifySig() const {} }; // Define SameOperandsAndResultTypeTraitOp1. @@ -279,7 +279,7 @@ class SameOperandsAndResultTypeTraitOp1 static constexpr const char **attributes_name = nullptr; static void Build(pir::Builder &builder, // NOLINT pir::OperationArgument &argument) {} // NOLINT - void Verify() const {} + void VerifySig() const {} }; // Define SameOperandsAndResultTypeTraitOp2. @@ -295,7 +295,7 @@ class SameOperandsAndResultTypeTraitOp2 pir::OperationArgument &argument, // NOLINT pir::Value l_operand, pir::Value r_operand); - void Verify() const {} + void VerifySig() const {} }; // Define SameOperandsAndResultTypeTraitOp3. @@ -315,7 +315,7 @@ class SameOperandsAndResultTypeTraitOp3 pir::Type out_type1, pir::Type out_type2); - void Verify() const {} + void VerifySig() const {} }; } // namespace test