diff --git a/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc b/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc
index 36fe9e340fcd9..5e7d3e6d876cf 100644
--- a/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc
+++ b/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc
@@ -324,12 +324,12 @@ void SplitOp::Build(pir::Builder& builder,  // NOLINT
 const char* GenerateShapeOp::attributes_name[attributes_num] = {
     "output_dim_exprs", "symbol_bindings"};
 
-void GenerateShapeOp::Build(
-    pir::Builder& builder,
-    pir::OperationArgument& argument,
-    const std::vector<pir::Value>& inputs,
-    const std::vector<pir::Attribute>& output_dim_exprs,
-    const GenerateShapeOp::SymbolBindings& symbol_bindings) {
+void GenerateShapeOp::Build(pir::Builder& builder,
+                            pir::OperationArgument& argument,
+                            const std::vector<pir::Value>& inputs,
+                            const std::vector<pir::Attribute>& output_dim_exprs,
+                            const SymbolBindings& symbol_bindings,
+                            const pir::Type& output_type) {
   if (inputs.empty()) {
     VLOG(3) << "GenerateShapeOp inputs is empty";
     for (const auto& attr : output_dim_exprs) {
@@ -344,13 +344,7 @@ void GenerateShapeOp::Build(
   argument.AddAttribute(
       "symbol_bindings",
       ConvertSymbolBindingsToAttribute(builder, symbol_bindings));
-  argument.AddOutputs({[&]() {
-    auto* ctx = pir::IrContext::Instance();
-    auto type = pir::Int64Type::get(ctx);
-    auto dim =
-        ::common::make_ddim({static_cast<int64_t>(output_dim_exprs.size())});
-    return DenseTensorType::get(ctx, type, dim);
-  }()});
+  argument.AddOutput(output_type);
   ::pir::PassStopGradientsDefaultly(argument);
 }
 
diff --git a/paddle/cinn/hlir/dialect/operator/ir/manual_op.h b/paddle/cinn/hlir/dialect/operator/ir/manual_op.h
index 1eddfaffd0df1..06f306a0e3623 100644
--- a/paddle/cinn/hlir/dialect/operator/ir/manual_op.h
+++ b/paddle/cinn/hlir/dialect/operator/ir/manual_op.h
@@ -168,7 +168,8 @@ class IR_API GenerateShapeOp
                     pir::OperationArgument &argument,  // NOLINT
                     const std::vector<pir::Value> &inputs,
                     const std::vector<pir::Attribute> &output_dim_exprs,
-                    const SymbolBindings &symbol_bindings);
+                    const SymbolBindings &symbol_bindings,
+                    const pir::Type &output_type);
 
   void VerifySig() {}
 
diff --git a/paddle/cinn/hlir/dialect/operator/transforms/check_infer_symbolic_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/check_infer_symbolic_pass.cc
index 63d5b519ce887..ec82d41742a70 100644
--- a/paddle/cinn/hlir/dialect/operator/transforms/check_infer_symbolic_pass.cc
+++ b/paddle/cinn/hlir/dialect/operator/transforms/check_infer_symbolic_pass.cc
@@ -232,7 +232,7 @@ class BlockDimExprsAsserter {
     };
     std::vector<pir::Value> input_tensors{};
     std::vector<pir::Attribute> output_dim_expr_attrs{};
-    GenerateShapeOp::SymbolBindings symbol_bindings{};
+    SymbolBindings symbol_bindings{};
     bool success =
         MakeGenerateShapeOpAttribute(ir_ctx_,
                                      LocalDimExprs4Value,
@@ -242,14 +242,13 @@ class BlockDimExprsAsserter {
                                      &output_dim_expr_attrs,
                                      &symbol_bindings);
     if (!success) return std::nullopt;
-    auto out_shape_value =
-        builder_
-            .Build<cinn::dialect::GenerateShapeOp>(
-                input_tensors, output_dim_expr_attrs, symbol_bindings)
-            .out();
+    auto out_type = paddle::dialect::DenseTensorType::get(
+        builder_.ir_context(),
+        pir::Int64Type::get(builder_.ir_context()),
+        ::common::make_ddim({dim_exprs.size()}));
     return builder_
         .Build<cinn::dialect::GenerateShapeOp>(
-            input_tensors, output_dim_expr_attrs, symbol_bindings)
+            input_tensors, output_dim_expr_attrs, symbol_bindings, out_type)
         .out();
   }
 
@@ -298,8 +297,11 @@ class BlockDimExprsAsserter {
     PADDLE_ENFORCE_EQ(lhs_numel,
                       rhs_numel,
                       ::common::errors::InvalidArgument(
+                          "Check [%s id:%d] infer symbolic shape failed."
"The numel of lhs and rhs must be equal, but " "received lhs's numel is [%d], rhs's numel is [%d]", + op->name(), + op->id(), lhs_numel, rhs_numel)); @@ -326,8 +328,8 @@ class BlockDimExprsAsserter { .out(); auto assert_op = builder_.Build( all_eq, assert_data, lhs_numel); - const std::string error_msg = "Check [" + op->name() + "_" + - std::to_string(op->id()) + + const std::string error_msg = "Check [" + op->name() + + " id:" + std::to_string(op->id()) + "] infer symbolic shape failed."; assert_op->set_attribute( paddle::dialect::AssertOp::ERROR_INFO_ATTR_NAME, diff --git a/paddle/cinn/hlir/dialect/operator/transforms/cinn_to_pd_util.cc b/paddle/cinn/hlir/dialect/operator/transforms/cinn_to_pd_util.cc index 6281baeadbef2..ca422c1a593c8 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/cinn_to_pd_util.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/cinn_to_pd_util.cc @@ -190,6 +190,15 @@ ::pir::Operation* ConvertConcatOp(::pir::Operation* op, return pd_op; } +::pir::Operation* ConvertGenerateShapeOp( + ::pir::Operation* op, + ::pir::IrMapping& ir_mapping, // NOLINT + ::pir::Builder& builder) { // NOLINT + auto* new_op = op->Clone(ir_mapping, {true, true, true}); + builder.Insert(new_op); + return new_op; +} + ::pir::Operation* ConvertScaleOp(::pir::Operation* op, ::pir::IrMapping& ir_mapping, // NOLINT ::pir::PatternRewriter& rewriter) { // NOLINT @@ -404,6 +413,9 @@ REGISTER_TRANSFORM_RULES(concat_op, cinn::dialect::ConcatOp::name(), cinn::dialect::details::ConvertConcatOp); +REGISTER_TRANSFORM_RULES(generate_shape_op, + cinn::dialect::GenerateShapeOp::name(), + cinn::dialect::details::ConvertGenerateShapeOp); REGISTER_TRANSFORM_RULES(scale_op, cinn::dialect::ScaleOp::name(), cinn::dialect::details::ConvertScaleOp); diff --git a/paddle/cinn/hlir/dialect/operator/transforms/dynamic_reshape_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/dynamic_reshape_pass.cc index 17317924fb07e..ba825819b0b2e 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/dynamic_reshape_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/dynamic_reshape_pass.cc @@ -64,8 +64,16 @@ bool ReplaceOpWithReshapeOp(pir::Operation* op, } } } + auto out_type = paddle::dialect::DenseTensorType::get( + rewriter.ir_context(), + pir::Int64Type::get(rewriter.ir_context()), + ::common::make_ddim( + {static_cast(output_dim_expr_attrs.size())})); auto cinn_generate_shape = rewriter.Build( - std::vector{input}, output_dim_expr_attrs, symbol_bindings); + std::vector{input}, + output_dim_expr_attrs, + symbol_bindings, + out_type); auto pd_reshape = rewriter.Build( op->operand_source(0), cinn_generate_shape.result(0)); diff --git a/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc index 0578c79b35a2b..207150187e23e 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc @@ -313,9 +313,18 @@ std::optional GetOutOfRewrittenGenerateShapeOp( &output_dim_expr_attrs, &symbol_bindings); if (!success) return std::nullopt; + auto out_type = [&]() -> pir::Type { + if (shape.type().isa()) { + return shape.type(); + } + return paddle::dialect::DenseTensorType::get( + rewriter->ir_context(), + pir::Int64Type::get(rewriter->ir_context()), + ::common::make_ddim({output_dim_expr_attrs.size()})); + }(); return rewriter ->Build( - input_tensors, 
+          input_tensors, output_dim_expr_attrs, symbol_bindings, out_type)
       .out();
 }
 
@@ -323,9 +332,9 @@ bool ReplaceShapeOpsToGenerateShape(
     pir::OpOperand shape_operand,
     pir::PatternRewriter* rewriter,
     pir::ShapeConstraintIRAnalysis* shape_analysis) {
-  if (shape_operand.source()
-          .defining_op()
-          ->isa<cinn::dialect::GenerateShapeOp>()) {
+  auto* shape_def_op = shape_operand.source().defining_op();
+  if (!shape_def_op || shape_def_op->num_operands() == 0) return false;
+  if (shape_def_op->isa<cinn::dialect::GenerateShapeOp>()) {
     return false;
   }
   auto ShapeOrDataDimExprs4Value =
@@ -379,6 +388,82 @@ class FuseShapeOpsIntoGenerateShapeOpPattern
   }
 };
 
+class FuseSingleElementShapeOpsIntoGenerateShapeOpPattern
+    : public pir::RewritePattern {
+ public:
+  explicit FuseSingleElementShapeOpsIntoGenerateShapeOpPattern(
+      pir::IrContext* context)
+      : pir::RewritePattern(MatchAnyOpTypeTag(),
+                            1 /*benefit*/,
+                            context,
+                            {} /*generated_names*/) {}
+
+  bool Match(pir::Operation* op) const override {
+    auto& shape_analysis =
+        pir::ShapeAnalysisManager::Instance().Get(op->GetParentProgram());
+    if (!IsSingleElementShapeOp(op, &shape_analysis)) return false;
+    if (op->isa<cinn::dialect::GenerateShapeOp>()) return false;
+
+    // all user ops' outputs should have no data of shape expr
+    pir::Value output = op->result(0);
+    if (output.use_empty()) return false;
+    for (auto iter = output.use_begin(); iter != output.use_end(); ++iter) {
+      auto* user = iter->owner();
+      if (IsSingleElementShapeOp(user, &shape_analysis)) return false;
+      if (user->isa<cinn::dialect::GenerateShapeOp>()) return false;
+    }
+
+    return true;
+  }
+
+  void Rewrite(pir::Operation* op,
+               pir::PatternRewriter& rewriter) const override {
+    auto& shape_analysis =
+        pir::ShapeAnalysisManager::Instance().Get(op->GetParentProgram());
+
+    auto ShapeOrDataDimExprs4Value =
+        [&shape_analysis](
+            pir::Value value) -> const symbol::ShapeOrDataDimExprs& {
+      return shape_analysis.GetShapeOrDataForValue(value);
+    };
+    std::optional<pir::Value> opt_generated_shape =
+        GetOutOfRewrittenGenerateShapeOp(
+            op->result(0), &rewriter, ShapeOrDataDimExprs4Value);
+    if (!opt_generated_shape.has_value()) {
+      LOG(WARNING) << "Create GenerateShapeOp Failed.";
+      return;
+    }
+
+    rewriter.ReplaceAllUsesWith(op->result(0), opt_generated_shape.value());
+
+    if (op->use_empty()) {
+      rewriter.EraseOp(op);
+    }
+  }
+
+ private:
+  bool IsSingleElementShapeOp(
+      pir::Operation* op,
+      pir::ShapeConstraintIRAnalysis* shape_analysis) const {
+    if (op->num_operands() == 0) return false;
+    if (op->num_results() != 1) return false;
+
+    pir::Value output = op->result(0);
+    const auto& out_shape = shape_analysis->GetShapeOrDataForValue(output);
+    if (!out_shape.isa<symbol::TensorShapeOrDataDimExprs>()) return false;
+    if (!out_shape.data().has_value()) return false;
+
+    auto dtype =
+        output.type().dyn_cast<paddle::dialect::DenseTensorType>().dtype();
+    if (!dtype.isa<pir::Int32Type>() && !dtype.isa<pir::Int64Type>()) {
+      return false;
+    }
+
+    // Only process ops whose output is a single element
+    return out_shape.data()->size() == 1;
+  }
+};
+
 class FuseShapeOpsIntoGenerateShapeOpPass : public pir::PatternRewritePass {
  public:
   FuseShapeOpsIntoGenerateShapeOpPass()
@@ -393,6 +478,7 @@ class FuseShapeOpsIntoGenerateShapeOpPass : public pir::PatternRewritePass {
         context);
     ps.Add<FuseShapeOpsIntoGenerateShapeOpPattern<...>>(
         context);
+    ps.Add<FuseSingleElementShapeOpsIntoGenerateShapeOpPattern>(context);
 
     return ps;
   }
diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/generate_shape_util.cc b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/generate_shape_util.cc
index 30b470d42ca2a..f2afbae3d515d 100644
--- a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/generate_shape_util.cc
+++ b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/generate_shape_util.cc
@@ -83,8 +83,10 @@ std::optional InsertGenerateShapeOpToRunFirst(
                                    &symbol_bindings);
   if (success) {
     return builder
-        ->Build<cinn::dialect::GenerateShapeOp>(
-            minimal_inputs, output_dim_expr_attrs, symbol_bindings)
+        ->Build<cinn::dialect::GenerateShapeOp>(minimal_inputs,
+                                                output_dim_expr_attrs,
+                                                symbol_bindings,
+                                                value.type())
         .out();
   }
   return std::nullopt;
diff --git a/paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/broadcast_with_cf.cc b/paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/broadcast_with_cf.cc
index 8f0bab178d75c..c3daa04fc2f4e 100644
--- a/paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/broadcast_with_cf.cc
+++ b/paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/broadcast_with_cf.cc
@@ -233,17 +233,24 @@ std::tuple BroadcastableToCondValue(
       &rhs_symbol_bindings);
   CHECK(success);
+  auto out_type = paddle::dialect::DenseTensorType::get(
+      builder.ir_context(),
+      pir::Int64Type::get(builder.ir_context()),
+      ::common::make_ddim({1}));
+
   auto lhs_value =
       builder
           .Build<cinn::dialect::GenerateShapeOp>(lhs_minimal_inputs,
                                                  lhs_output_dim_expr_attrs,
-                                                 lhs_symbol_bindings)
+                                                 lhs_symbol_bindings,
+                                                 out_type)
           .out();
   auto rhs_value =
       builder
          .Build<cinn::dialect::GenerateShapeOp>(rhs_minimal_inputs,
                                                 rhs_output_dim_expr_attrs,
-                                                rhs_symbol_bindings)
+                                                rhs_symbol_bindings,
+                                                out_type)
          .out();
 
   auto const_one =
       builder
diff --git a/test/ir/pir/cinn/symbolic/test_dyshape_group_norm.py b/test/ir/pir/cinn/symbolic/test_dyshape_group_norm.py
new file mode 100644
index 0000000000000..a3e9b838eeae4
--- /dev/null
+++ b/test/ir/pir/cinn/symbolic/test_dyshape_group_norm.py
@@ -0,0 +1,91 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import unittest
+from os.path import dirname
+
+import numpy as np
+
+import paddle
+from paddle import nn
+from paddle.static import InputSpec
+
+sys.path.append(dirname(dirname(__file__)))
+
+import utils
+
+
+class GroupNorm(nn.Layer):
+    def __init__(self):
+        super().__init__()
+        self.hidden_size = 768
+        self.dtype = "float32"
+        self.weight = paddle.randn([128], dtype=self.dtype)
+        self.weight.stop_gradient = False
+        self.bias = paddle.randn([128], dtype=self.dtype)
+        self.bias.stop_gradient = False
+
+        self.data_format = "NHWC"
+
+    def forward(self, x):
+        return paddle.nn.functional.group_norm(
+            x,
+            num_groups=32,
+            epsilon=1e-6,
+            weight=self.weight,
+            bias=self.bias,
+            data_format=self.data_format,
+        )
+
+
+class TestGroupNorm(unittest.TestCase):
+    def setUp(self):
+        paddle.seed(2024)
+        self.shape = [1, 128, 256, 128]
+        self.dtype = "float32"
+        self.data_format = "NHWC"
+        self.prepare_data()
+
+    def prepare_data(self):
+        self.x = paddle.randn(self.shape, dtype=self.dtype)
+        self.x.stop_gradient = False
+
+    def check_jit_kernel_info(self, static_fn):
+        utils.check_jit_kernel_number(static_fn, 2)
+        utils.check_jit_kernel_structure(static_fn, {utils.JIT_KERNEL_NAME: 2})
+
+    def eval(self, use_cinn):
+        paddle.seed(2024)
+        net = GroupNorm()
+        input_spec = [
+            InputSpec(shape=[None, None, None, 128], dtype='float32'),
+        ]
+        net = utils.apply_to_static(net, use_cinn, input_spec)
+        net.eval()
+        out = net(self.x)
+        if use_cinn:
+            self.check_jit_kernel_info(net.forward)
+        return out
+
+    def test_eval(self):
+        cinn_out = self.eval(use_cinn=True)
+        dy_out = self.eval(use_cinn=False)
+        np.testing.assert_allclose(
+            cinn_out.numpy(), dy_out.numpy(), atol=1e-6, rtol=1e-6
+        )
+
+
+if __name__ == '__main__':
+    unittest.main()
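
Note: for orientation, the call-site pattern introduced by this change is sketched below. It is a minimal sketch assembled from the call sites in the diff (dynamic_reshape_pass.cc, check_infer_symbolic_pass.cc); the names `builder`, `inputs`, `output_dim_exprs`, and `symbol_bindings` stand for values each call site already has, so this is illustrative rather than a drop-in unit.

    // The caller now decides the output type instead of GenerateShapeOp::Build
    // hard-coding it: here, a 1-D int64 tensor with one element per dim expr.
    auto out_type = paddle::dialect::DenseTensorType::get(
        builder.ir_context(),
        pir::Int64Type::get(builder.ir_context()),
        ::common::make_ddim({static_cast<int64_t>(output_dim_exprs.size())}));

    // Build the op with the extra output_type argument and take its result.
    pir::Value shape_value =
        builder
            .Build<cinn::dialect::GenerateShapeOp>(
                inputs, output_dim_exprs, symbol_bindings, out_type)
            .out();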