Skip to content

Commit

Permalink
fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangbo9674 committed Sep 4, 2023
1 parent e2af9d5 commit 000a8db
Showing 1 changed file with 70 additions and 9 deletions.
79 changes: 70 additions & 9 deletions paddle/fluid/ir/dialect/op_generator/op_build_gen.py
Expand Up @@ -13,6 +13,10 @@
# limitations under the License.

# generator build function
_INFERMETA_NEED_META_CONFIG = {'SplitInferMeta'}

_PREPARE_DATA_WITH_UNKNOW_ATTRIBUTE = {'SplitOp'}

OP_BUILD_TEMPLATE = """
void {op_name}::Build({build_args}) {{
{get_attributes}
Expand Down Expand Up @@ -273,6 +277,7 @@ def GenBuildAttributes(


def GenBuildOutputs(
op_class_name,
op_input_name_list,
op_input_type_list,
op_mutable_attribute_name_list,
Expand Down Expand Up @@ -318,6 +323,40 @@ def GenBuildOutputs(
CREATE_INTARRAY_MUTABLE_ATTRIBUE_TEMPLATE = """ std::vector<int64_t> {name} = {name}_.owner()->dyn_cast<paddle::dialect::FullIntArrayOp>().attributes().at("value").dyn_cast<paddle::dialect::IntArrayAttribute>().data().GetData(); (void){name};\n"""
CREATE_SCALAR_MUTABLE_ATTRIBUE_TEMPLATE = """ {dtype} {name} = {name}_.owner()->dyn_cast<paddle::dialect::FullOp>().attributes().at("value").dyn_cast<paddle::dialect::ScalarAttribute>().data().to<{dtype}>(); (void){name};\n"""

CREATE_INTARRAY_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE = """ phi::IntArray {name};
if ({name}_.owner()->info().id() == ir::TypeId::get<paddle::dialect::FullIntArrayOp>()) {{
{name} = std::move(phi::IntArray({name}_.owner()
->dyn_cast<paddle::dialect::FullIntArrayOp>()
.attributes()
.at("value")
.dyn_cast<paddle::dialect::IntArrayAttribute>()
.data()
.GetData()));
}}
else {{
PADDLE_ENFORCE(
{name}_.type().isa<ir::VectorType>(),
phi::errors::PreconditionNotMet("section Type should be VectorType."));
size_t {name}_size = {name}_.type().dyn_cast<ir::VectorType>().size();
{name} = std::move(phi::IntArray(std::vector<int64_t>({name}_size, -1)));
{name}.SetFromTensor(true);
}}\n"""

CREATE_SCALAR_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE = """ phi::Scalar {name};
if ({name}_.owner()->info().id() == ir::TypeId::get<paddle::dialect::FullOp>()) {{
{name} = std::move(phi::Scalar({name}_.owner()
->dyn_cast<paddle::dialect::FullOp>()
.attributes()
.at("value")
.dyn_cast<paddle::dialect::ScalarAttribute>()
.data()
.to<int>()));
}}
else {{
{name} = std::move(phi::Scalar(-1));
{name}.SetFromTensor(true);
}}\n"""

CREATE_OUTPUT_METATENSOR_TEMPLATE = """ phi::DenseTensor dense_{name};
phi::MetaTensor meta_{name}(&dense_{name});
"""
Expand Down Expand Up @@ -350,19 +389,30 @@ def GenBuildOutputs(
attr_dtype = op_mutable_attribute_type_list[idx]
# int_array
if attr_dtype[0] == "paddle::dialect::IntArrayAttribute":
build_output_str += (
CREATE_INTARRAY_MUTABLE_ATTRIBUE_TEMPLATE.format(
if op_class_name in _PREPARE_DATA_WITH_UNKNOW_ATTRIBUTE:
build_output_str += CREATE_INTARRAY_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE.format(
name=op_mutable_attribute_name_list[idx]
)
)
else:
build_output_str += (
CREATE_INTARRAY_MUTABLE_ATTRIBUE_TEMPLATE.format(
name=op_mutable_attribute_name_list[idx]
)
)
# scalar
elif attr_dtype[0] == "paddle::dialect::ScalarAttribute":
build_output_str += (
CREATE_SCALAR_MUTABLE_ATTRIBUE_TEMPLATE.format(
if op_class_name in _PREPARE_DATA_WITH_UNKNOW_ATTRIBUTE:
build_output_str += CREATE_SCALAR_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE.format(
name=op_mutable_attribute_name_list[idx],
dtype=attr_dtype[1],
)
)
else:
build_output_str += (
CREATE_SCALAR_MUTABLE_ATTRIBUE_TEMPLATE.format(
name=op_mutable_attribute_name_list[idx],
dtype=attr_dtype[1],
)
)
# string
elif attr_dtype[0] == "ir::StrAttribute":
build_output_str += ""
Expand Down Expand Up @@ -423,9 +473,19 @@ def GenBuildOutputs(
CREATE_INFER_META_FUNC_TEMPLATE = """
phi::{func}({args});
"""
build_output_str += CREATE_INFER_META_FUNC_TEMPLATE.format(
func=op_infer_meta_map['func'], args=", ".join(infer_meta_args)
)
CREATE_INFER_META_FUNC_WITH_METACINFIG_TEMPLATE = """
phi::{func}({args}, phi::MetaConfig(false, false));
"""
if op_infer_meta_map['func'] in _INFERMETA_NEED_META_CONFIG:
build_output_str += (
CREATE_INFER_META_FUNC_WITH_METACINFIG_TEMPLATE.format(
func=op_infer_meta_map['func'], args=", ".join(infer_meta_args)
)
)
else:
build_output_str += CREATE_INFER_META_FUNC_TEMPLATE.format(
func=op_infer_meta_map['func'], args=", ".join(infer_meta_args)
)

# use dense_{name} or vec_dense_{name} to create Outputs type
build_output_str += "\n std::vector<ir::Type> argument_outputs;"
Expand Down Expand Up @@ -530,6 +590,7 @@ def gen_build_func_str(
op_non_mutable_attribute_type_list,
)
build_outputs_str = GenBuildOutputs(
op_class_name,
op_input_name_list,
op_input_type_list,
op_mutable_attribute_name_list,
Expand Down

0 comments on commit 000a8db

Please sign in to comment.