diff --git a/paddle/fluid/ir/dialect/op_generator/op_build_gen.py b/paddle/fluid/ir/dialect/op_generator/op_build_gen.py index d9747b47e8747..429c34529cd3d 100644 --- a/paddle/fluid/ir/dialect/op_generator/op_build_gen.py +++ b/paddle/fluid/ir/dialect/op_generator/op_build_gen.py @@ -13,6 +13,10 @@ # limitations under the License. # generator build function +_INFERMETA_NEED_META_CONFIG = {'SplitInferMeta'} + +_PREPARE_DATA_WITH_UNKNOW_ATTRIBUTE = {'SplitOp'} + OP_BUILD_TEMPLATE = """ void {op_name}::Build({build_args}) {{ {get_attributes} @@ -273,6 +277,7 @@ def GenBuildAttributes( def GenBuildOutputs( + op_class_name, op_input_name_list, op_input_type_list, op_mutable_attribute_name_list, @@ -318,6 +323,40 @@ def GenBuildOutputs( CREATE_INTARRAY_MUTABLE_ATTRIBUE_TEMPLATE = """ std::vector {name} = {name}_.owner()->dyn_cast().attributes().at("value").dyn_cast().data().GetData(); (void){name};\n""" CREATE_SCALAR_MUTABLE_ATTRIBUE_TEMPLATE = """ {dtype} {name} = {name}_.owner()->dyn_cast().attributes().at("value").dyn_cast().data().to<{dtype}>(); (void){name};\n""" + CREATE_INTARRAY_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE = """ phi::IntArray {name}; + if ({name}_.owner()->info().id() == ir::TypeId::get()) {{ + {name} = std::move(phi::IntArray({name}_.owner() + ->dyn_cast() + .attributes() + .at("value") + .dyn_cast() + .data() + .GetData())); + }} + else {{ + PADDLE_ENFORCE( + {name}_.type().isa(), + phi::errors::PreconditionNotMet("section Type should be VectorType.")); + size_t {name}_size = {name}_.type().dyn_cast().size(); + {name} = std::move(phi::IntArray(std::vector({name}_size, -1))); + {name}.SetFromTensor(true); + }}\n""" + + CREATE_SCALAR_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE = """ phi::Scalar {name}; + if ({name}_.owner()->info().id() == ir::TypeId::get()) {{ + {name} = std::move(phi::Scalar({name}_.owner() + ->dyn_cast() + .attributes() + .at("value") + .dyn_cast() + .data() + .to())); + }} + else {{ + {name} = std::move(phi::Scalar(-1)); + {name}.SetFromTensor(true); + }}\n""" + CREATE_OUTPUT_METATENSOR_TEMPLATE = """ phi::DenseTensor dense_{name}; phi::MetaTensor meta_{name}(&dense_{name}); """ @@ -350,19 +389,30 @@ def GenBuildOutputs( attr_dtype = op_mutable_attribute_type_list[idx] # int_array if attr_dtype[0] == "paddle::dialect::IntArrayAttribute": - build_output_str += ( - CREATE_INTARRAY_MUTABLE_ATTRIBUE_TEMPLATE.format( + if op_class_name in _PREPARE_DATA_WITH_UNKNOW_ATTRIBUTE: + build_output_str += CREATE_INTARRAY_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE.format( name=op_mutable_attribute_name_list[idx] ) - ) + else: + build_output_str += ( + CREATE_INTARRAY_MUTABLE_ATTRIBUE_TEMPLATE.format( + name=op_mutable_attribute_name_list[idx] + ) + ) # scalar elif attr_dtype[0] == "paddle::dialect::ScalarAttribute": - build_output_str += ( - CREATE_SCALAR_MUTABLE_ATTRIBUE_TEMPLATE.format( + if op_class_name in _PREPARE_DATA_WITH_UNKNOW_ATTRIBUTE: + build_output_str += CREATE_SCALAR_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE.format( name=op_mutable_attribute_name_list[idx], dtype=attr_dtype[1], ) - ) + else: + build_output_str += ( + CREATE_SCALAR_MUTABLE_ATTRIBUE_TEMPLATE.format( + name=op_mutable_attribute_name_list[idx], + dtype=attr_dtype[1], + ) + ) # string elif attr_dtype[0] == "ir::StrAttribute": build_output_str += "" @@ -423,9 +473,19 @@ def GenBuildOutputs( CREATE_INFER_META_FUNC_TEMPLATE = """ phi::{func}({args}); """ - build_output_str += CREATE_INFER_META_FUNC_TEMPLATE.format( - func=op_infer_meta_map['func'], args=", ".join(infer_meta_args) - ) + CREATE_INFER_META_FUNC_WITH_METACINFIG_TEMPLATE = """ + phi::{func}({args}, phi::MetaConfig(false, false)); +""" + if op_infer_meta_map['func'] in _INFERMETA_NEED_META_CONFIG: + build_output_str += ( + CREATE_INFER_META_FUNC_WITH_METACINFIG_TEMPLATE.format( + func=op_infer_meta_map['func'], args=", ".join(infer_meta_args) + ) + ) + else: + build_output_str += CREATE_INFER_META_FUNC_TEMPLATE.format( + func=op_infer_meta_map['func'], args=", ".join(infer_meta_args) + ) # use dense_{name} or vec_dense_{name} to create Outputs type build_output_str += "\n std::vector argument_outputs;" @@ -530,6 +590,7 @@ def gen_build_func_str( op_non_mutable_attribute_type_list, ) build_outputs_str = GenBuildOutputs( + op_class_name, op_input_name_list, op_input_type_list, op_mutable_attribute_name_list,