Commit

Revert "[CINN]Support more shape ops fuse to generate shape op (#64216)"

This reverts commit af93c4f.

zyfncg committed Jun 3, 2024
1 parent 7a8eda5 commit 72a893f
Showing 9 changed files with 32 additions and 235 deletions.
20 changes: 13 additions & 7 deletions paddle/cinn/hlir/dialect/operator/ir/manual_op.cc
@@ -324,12 +324,12 @@ void SplitOp::Build(pir::Builder& builder,  // NOLINT
const char* GenerateShapeOp::attributes_name[attributes_num] = {
"output_dim_exprs", "symbol_bindings"};

void GenerateShapeOp::Build(pir::Builder& builder,
pir::OperationArgument& argument,
const std::vector<pir::Value>& inputs,
const std::vector<pir::Attribute>& output_dim_exprs,
const SymbolBindings& symbol_bindings,
const pir::Type& output_type) {
void GenerateShapeOp::Build(
pir::Builder& builder,
pir::OperationArgument& argument,
const std::vector<pir::Value>& inputs,
const std::vector<pir::Attribute>& output_dim_exprs,
const GenerateShapeOp::SymbolBindings& symbol_bindings) {
if (inputs.empty()) {
VLOG(3) << "GenerateShapeOp inputs is empty";
for (const auto& attr : output_dim_exprs) {
@@ -344,7 +344,13 @@ void GenerateShapeOp::Build(pir::Builder& builder,
argument.AddAttribute(
"symbol_bindings",
ConvertSymbolBindingsToAttribute(builder, symbol_bindings));
argument.AddOutput(output_type);
argument.AddOutputs({[&]() {
auto* ctx = pir::IrContext::Instance();
auto type = pir::Int64Type::get(ctx);
auto dim =
::common::make_ddim({static_cast<int64_t>(output_dim_exprs.size())});
return DenseTensorType::get(ctx, type, dim);
}()});
::pir::PassStopGradientsDefaultly(argument);
}

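In practice every call site in this commit drops the trailing output-type argument. A minimal before/after sketch of a caller (illustrative names; assumes a pir::Builder plus inputs, attrs, and bindings already in scope, as in the hunks below):

// Call style removed by this revert (from #64216): explicit output type.
pir::Value out_with_type =
    builder
        .Build<cinn::dialect::GenerateShapeOp>(
            inputs, output_dim_exprs, symbol_bindings, output_type)
        .out();

// Call style restored by this revert: Build() infers the output type as a
// 1-D int64 DenseTensorType of length output_dim_exprs.size().
pir::Value out_inferred =
    builder
        .Build<cinn::dialect::GenerateShapeOp>(
            inputs, output_dim_exprs, symbol_bindings)
        .out();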
3 changes: 1 addition & 2 deletions paddle/cinn/hlir/dialect/operator/ir/manual_op.h
@@ -168,8 +168,7 @@ class IR_API GenerateShapeOp
pir::OperationArgument &argument, // NOLINT
const std::vector<pir::Value> &inputs,
const std::vector<pir::Attribute> &output_dim_exprs,
const SymbolBindings &symbol_bindings,
const pir::Type &output_type);
const SymbolBindings &symbol_bindings);

void VerifySig() {}

@@ -232,7 +232,7 @@ class BlockDimExprsAsserter {
};
std::vector<pir::Value> input_tensors{};
std::vector<pir::Attribute> output_dim_expr_attrs{};
SymbolBindings symbol_bindings{};
GenerateShapeOp::SymbolBindings symbol_bindings{};
bool success =
MakeGenerateShapeOpAttribute(ir_ctx_,
LocalDimExprs4Value,
@@ -242,13 +242,14 @@
&output_dim_expr_attrs,
&symbol_bindings);
if (!success) return std::nullopt;
auto out_type = paddle::dialect::DenseTensorType::get(
builder_.ir_context(),
pir::Int64Type::get(builder_.ir_context()),
::common::make_ddim({dim_exprs.size()}));
auto out_shape_value =
builder_
.Build<cinn::dialect::GenerateShapeOp>(
input_tensors, output_dim_expr_attrs, symbol_bindings)
.out();
return builder_
.Build<cinn::dialect::GenerateShapeOp>(
input_tensors, output_dim_expr_attrs, symbol_bindings, out_type)
input_tensors, output_dim_expr_attrs, symbol_bindings)
.out();
}

@@ -297,11 +298,8 @@ class BlockDimExprsAsserter {
PADDLE_ENFORCE_EQ(lhs_numel,
rhs_numel,
::common::errors::InvalidArgument(
"Check [%s id:%d] infer symbolic shape failed."
"The numel of lhs and rhs must be equal, but "
"received lhs's numel is [%d], rhs's numel is [%d]",
op->name(),
op->id(),
lhs_numel,
rhs_numel));

@@ -328,8 +326,8 @@ class BlockDimExprsAsserter {
.out();
auto assert_op = builder_.Build<paddle::dialect::AssertOp>(
all_eq, assert_data, lhs_numel);
const std::string error_msg = "Check [" + op->name() +
" id:" + std::to_string(op->id()) +
const std::string error_msg = "Check [" + op->name() + "_" +
std::to_string(op->id()) +
"] infer symbolic shape failed.";
assert_op->set_attribute(
paddle::dialect::AssertOp::ERROR_INFO_ATTR_NAME,
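Two message changes in this file are easy to miss: the PADDLE_ENFORCE_EQ text loses its op name/id prefix, and the assert error string switches from "[name id:N]" to "[name_N]". An illustrative comparison (op name and id are made up):

// Illustrative values: op->name() == "pd_op.reshape", op->id() == 42.
// Before the revert (per #64216):
//   "Check [pd_op.reshape id:42] infer symbolic shape failed."
// After the revert, as built by the restored code above:
//   "Check [pd_op.reshape_42] infer symbolic shape failed."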
12 changes: 0 additions & 12 deletions paddle/cinn/hlir/dialect/operator/transforms/cinn_to_pd_util.cc
@@ -190,15 +190,6 @@ ::pir::Operation* ConvertConcatOp(::pir::Operation* op,
return pd_op;
}

::pir::Operation* ConvertGenerateShapeOp(
::pir::Operation* op,
::pir::IrMapping& ir_mapping, // NOLINT
::pir::Builder& builder) { // NOLINT
auto* new_op = op->Clone(ir_mapping, {true, true, true});
builder.Insert(new_op);
return new_op;
}

::pir::Operation* ConvertScaleOp(::pir::Operation* op,
::pir::IrMapping& ir_mapping, // NOLINT
::pir::PatternRewriter& rewriter) { // NOLINT
@@ -413,9 +404,6 @@ REGISTER_TRANSFORM_RULES(concat_op,
cinn::dialect::ConcatOp::name(),
cinn::dialect::details::ConvertConcatOp);

REGISTER_TRANSFORM_RULES(generate_shape_op,
cinn::dialect::GenerateShapeOp::name(),
cinn::dialect::details::ConvertGenerateShapeOp);
REGISTER_TRANSFORM_RULES(scale_op,
cinn::dialect::ScaleOp::name(),
cinn::dialect::details::ConvertScaleOp);
@@ -65,16 +65,8 @@ bool ReplaceOpWithReshapeOp(pir::Operation* op,
}
}
}
auto out_type = paddle::dialect::DenseTensorType::get(
rewriter.ir_context(),
pir::Int64Type::get(rewriter.ir_context()),
::common::make_ddim(
{static_cast<int64_t>(output_dim_expr_attrs.size())}));
auto cinn_generate_shape = rewriter.Build<cinn::dialect::GenerateShapeOp>(
std::vector<pir::Value>{input},
output_dim_expr_attrs,
symbol_bindings,
out_type);
std::vector<pir::Value>{input}, output_dim_expr_attrs, symbol_bindings);
auto pd_reshape = rewriter.Build<paddle::dialect::ReshapeOp>(
op->operand_source(0), cinn_generate_shape.result(0));

@@ -313,28 +313,19 @@ std::optional<pir::Value> GetOutOfRewrittenGenerateShapeOp(
&output_dim_expr_attrs,
&symbol_bindings);
if (!success) return std::nullopt;
auto out_type = [&]() -> pir::Type {
if (shape.type().isa<paddle::dialect::DenseTensorType>()) {
return shape.type();
}
return paddle::dialect::DenseTensorType::get(
rewriter->ir_context(),
pir::Int64Type::get(rewriter->ir_context()),
::common::make_ddim({output_dim_expr_attrs.size()}));
}();
return rewriter
->Build<cinn::dialect::GenerateShapeOp>(
input_tensors, output_dim_expr_attrs, symbol_bindings, out_type)
input_tensors, output_dim_expr_attrs, symbol_bindings)
.out();
}

bool ReplaceShapeOpsToGenerateShape(
pir::OpOperand shape_operand,
pir::PatternRewriter* rewriter,
pir::ShapeConstraintIRAnalysis* shape_analysis) {
auto* shape_def_op = shape_operand.source().defining_op();
if (!shape_def_op || shape_def_op->num_operands() == 0) return false;
if (shape_def_op->isa<cinn::dialect::GenerateShapeOp>()) {
if (shape_operand.source()
.defining_op()
->isa<cinn::dialect::GenerateShapeOp>()) {
return false;
}
auto ShapeOrDataDimExprs4Value =
@@ -388,82 +379,6 @@ class FuseShapeOpsIntoGenerateShapeOpPattern
}
};

class FuseSingleElementShapeOpsIntoGenerateShapeOpPattern
: public pir::RewritePattern {
public:
explicit FuseSingleElementShapeOpsIntoGenerateShapeOpPattern(
pir::IrContext* context)
: pir::RewritePattern(MatchAnyOpTypeTag(),
1 /*benefit*/,
context,
{} /*generated_names*/) {}

bool Match(pir::Operation* op) const override {
auto& shape_analysis =
pir::ShapeAnalysisManager::Instance().Get(op->GetParentProgram());
if (!IsSingleElementShapeOp(op, &shape_analysis)) return false;
if (op->isa<cinn::dialect::GenerateShapeOp>()) return false;

// all user op's output should has no data of shape expr
pir::Value output = op->result(0);
if (output.use_empty()) return false;
for (auto iter = output.use_begin(); iter != output.use_end(); ++iter) {
auto* user = iter->owner();
if (IsSingleElementShapeOp(user, &shape_analysis)) return false;
if (user->isa<cinn::dialect::GenerateShapeOp>()) return false;
}

return true;
}

void Rewrite(pir::Operation* op,
pir::PatternRewriter& rewriter) const override {
auto& shape_analysis =
pir::ShapeAnalysisManager::Instance().Get(op->GetParentProgram());

auto ShapeOrDataDimExprs4Value =
[&shape_analysis](
pir::Value value) -> const symbol::ShapeOrDataDimExprs& {
return shape_analysis.GetShapeOrDataForValue(value);
};
std::optional<pir::Value> opt_generated_shape =
GetOutOfRewrittenGenerateShapeOp(
op->result(0), &rewriter, ShapeOrDataDimExprs4Value);
if (!opt_generated_shape.has_value()) {
LOG(WARNING) << "Create GenerateShapeOp Failed.";
return;
}

rewriter.ReplaceAllUsesWith(op->result(0), opt_generated_shape.value());

if (op->use_empty()) {
rewriter.EraseOp(op);
}
}

private:
bool IsSingleElementShapeOp(
pir::Operation* op,
pir::ShapeConstraintIRAnalysis* shape_analysis) const {
if (op->num_operands() == 0) return false;
if (op->num_results() != 1) return false;

pir::Value output = op->result(0);
const auto& out_shape = shape_analysis->GetShapeOrDataForValue(output);
if (!out_shape.isa<symbol::TensorShapeOrDataDimExprs>()) return false;
if (!out_shape.data().has_value()) return false;

auto dtype =
output.type().dyn_cast<paddle::dialect::DenseTensorType>().dtype();
if (!dtype.isa<pir::Int32Type>() && !dtype.isa<pir::Int64Type>()) {
return false;
}

// Only process the op which output is a single element
return out_shape.data()->size() == 1;
}
};

class FuseShapeOpsIntoGenerateShapeOpPass : public pir::PatternRewritePass {
public:
FuseShapeOpsIntoGenerateShapeOpPass()
@@ -478,7 +393,6 @@ class FuseShapeOpsIntoGenerateShapeOpPass : public pir::PatternRewritePass {
context);
ps.Add<FuseShapeOpsIntoGenerateShapeOpPattern<paddle::dialect::SliceOp>>(
context);
ps.Add<FuseSingleElementShapeOpsIntoGenerateShapeOpPattern>(context);
return ps;
}

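One behavioral detail worth flagging in ReplaceShapeOpsToGenerateShape above: the guard removed by this revert also bailed out when the shape operand had no defining op (e.g. a block argument) or a zero-operand producer, while the restored code dereferences defining_op() directly. A condensed side-by-side, rearranged from the hunk (assumes pir's OpOperand API as used there):

// Guard removed by this revert (from #64216): tolerates a null defining op
// and zero-operand producers before the GenerateShapeOp check.
auto* shape_def_op = shape_operand.source().defining_op();
if (!shape_def_op || shape_def_op->num_operands() == 0) return false;
if (shape_def_op->isa<cinn::dialect::GenerateShapeOp>()) return false;

// Guard restored by this revert: assumes a non-null defining op.
if (shape_operand.source().defining_op()
        ->isa<cinn::dialect::GenerateShapeOp>()) {
  return false;
}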
@@ -83,10 +83,8 @@ std::optional<pir::Value> InsertGenerateShapeOpToRunFirst(
&symbol_bindings);
if (success) {
return builder
->Build<cinn::dialect::GenerateShapeOp>(minimal_inputs,
output_dim_expr_attrs,
symbol_bindings,
value.type())
->Build<cinn::dialect::GenerateShapeOp>(
minimal_inputs, output_dim_expr_attrs, symbol_bindings)
.out();
}
return std::nullopt;
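A subtle difference at the call site above: the argument deleted here was value.type(), the replaced value's own type, whereas the restored Build() always infers a 1-D int64 DenseTensorType sized by output_dim_expr_attrs. A sketch of the two (they agree only if value was already such a tensor, which this diff alone does not confirm):

// Removed: output type taken from the value being materialized.
//   builder->Build<cinn::dialect::GenerateShapeOp>(..., value.type())
// Restored: output type inferred inside Build() as
//   DenseTensorType(int64, ddim({(int64_t)output_dim_expr_attrs.size()}))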
@@ -233,24 +233,17 @@ std::tuple<pir::Value, pir::Value, pir::Value> BroadcastableToCondValue(
&rhs_symbol_bindings);
CHECK(success);

auto out_type = paddle::dialect::DenseTensorType::get(
builder.ir_context(),
pir::Int64Type::get(builder.ir_context()),
::common::make_ddim({1}));

auto lhs_value =
builder
.Build<cinn::dialect::GenerateShapeOp>(lhs_minimal_inputs,
lhs_output_dim_expr_attrs,
lhs_symbol_bindings,
out_type)
lhs_symbol_bindings)
.out();
auto rhs_value =
builder
.Build<cinn::dialect::GenerateShapeOp>(rhs_minimal_inputs,
rhs_output_dim_expr_attrs,
rhs_symbol_bindings,
out_type)
rhs_symbol_bindings)
.out();

auto const_one = builder
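The explicit out_type deleted above was a one-element int64 tensor. Under the restored Build() the inferred type should coincide, assuming each of these two GenerateShapeOps emits exactly one output dim expr (plausible for lhs/rhs condition values, but not verifiable from this hunk alone):

// Type removed from both call sites above:
//   DenseTensorType(int64, ddim({1}))
// Type now inferred inside GenerateShapeOp::Build:
//   DenseTensorType(int64, ddim({(int64_t)output_dim_exprs.size()}))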
91 changes: 0 additions & 91 deletions test/ir/pir/cinn/symbolic/test_dyshape_group_norm.py

This file was deleted.
