[CINN] Support fusing more shape ops into GenerateShapeOp #64216

Merged (19 commits) on May 31, 2024
20 changes: 7 additions & 13 deletions paddle/cinn/hlir/dialect/operator/ir/manual_op.cc
@@ -324,12 +324,12 @@ void SplitOp::Build(pir::Builder& builder,  // NOLINT
const char* GenerateShapeOp::attributes_name[attributes_num] = {
"output_dim_exprs", "symbol_bindings"};

void GenerateShapeOp::Build(
pir::Builder& builder,
pir::OperationArgument& argument,
const std::vector<pir::Value>& inputs,
const std::vector<pir::Attribute>& output_dim_exprs,
const GenerateShapeOp::SymbolBindings& symbol_bindings) {
void GenerateShapeOp::Build(pir::Builder& builder,
pir::OperationArgument& argument,
const std::vector<pir::Value>& inputs,
const std::vector<pir::Attribute>& output_dim_exprs,
const SymbolBindings& symbol_bindings,
const pir::Type& output_type) {
if (inputs.empty()) {
VLOG(3) << "GenerateShapeOp inputs is empty";
for (const auto& attr : output_dim_exprs) {
@@ -344,13 +344,7 @@ void GenerateShapeOp::Build(
argument.AddAttribute(
"symbol_bindings",
ConvertSymbolBindingsToAttribute(builder, symbol_bindings));
argument.AddOutputs({[&]() {
auto* ctx = pir::IrContext::Instance();
auto type = pir::Int64Type::get(ctx);
auto dim =
::common::make_ddim({static_cast<int64_t>(output_dim_exprs.size())});
return DenseTensorType::get(ctx, type, dim);
}()});
argument.AddOutput(output_type);
::pir::PassStopGradientsDefaultly(argument);
}

3 changes: 2 additions & 1 deletion paddle/cinn/hlir/dialect/operator/ir/manual_op.h
@@ -168,7 +168,8 @@ class IR_API GenerateShapeOp
pir::OperationArgument &argument, // NOLINT
const std::vector<pir::Value> &inputs,
const std::vector<pir::Attribute> &output_dim_exprs,
const SymbolBindings &symbol_bindings);
const SymbolBindings &symbol_bindings,
const pir::Type &output_type);

void VerifySig() {}

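With this change, Build() no longer decides the output type itself, so every call site has to supply one. Below is a minimal caller-side sketch (illustrative only; builder, inputs, output_dim_exprs, and symbol_bindings stand in for whatever the call site already holds) that reproduces the rank-1 int64 tensor type the old implementation hard-coded:

// Illustrative call site for the new signature; mirrors the DenseTensorType
// that GenerateShapeOp::Build used to construct internally.
pir::IrContext* ctx = builder.ir_context();
auto out_type = paddle::dialect::DenseTensorType::get(
    ctx,
    pir::Int64Type::get(ctx),
    ::common::make_ddim({static_cast<int64_t>(output_dim_exprs.size())}));
auto shape_value =
    builder
        .Build<cinn::dialect::GenerateShapeOp>(
            inputs, output_dim_exprs, symbol_bindings, out_type)
        .out();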
@@ -232,7 +232,7 @@ class BlockDimExprsAsserter {
};
std::vector<pir::Value> input_tensors{};
std::vector<pir::Attribute> output_dim_expr_attrs{};
GenerateShapeOp::SymbolBindings symbol_bindings{};
SymbolBindings symbol_bindings{};
bool success =
MakeGenerateShapeOpAttribute(ir_ctx_,
LocalDimExprs4Value,
@@ -242,14 +242,13 @@ class BlockDimExprsAsserter {
&output_dim_expr_attrs,
&symbol_bindings);
if (!success) return std::nullopt;
auto out_shape_value =
builder_
.Build<cinn::dialect::GenerateShapeOp>(
input_tensors, output_dim_expr_attrs, symbol_bindings)
.out();
auto out_type = paddle::dialect::DenseTensorType::get(
builder_.ir_context(),
pir::Int64Type::get(builder_.ir_context()),
::common::make_ddim({dim_exprs.size()}));
return builder_
.Build<cinn::dialect::GenerateShapeOp>(
input_tensors, output_dim_expr_attrs, symbol_bindings)
input_tensors, output_dim_expr_attrs, symbol_bindings, out_type)
.out();
}

@@ -298,8 +297,11 @@ class BlockDimExprsAsserter {
PADDLE_ENFORCE_EQ(lhs_numel,
rhs_numel,
::common::errors::InvalidArgument(
"Check [%s id:%d] infer symbolic shape failed."
"The numel of lhs and rhs must be equal, but "
"received lhs's numel is [%d], rhs's numel is [%d]",
op->name(),
op->id(),
lhs_numel,
rhs_numel));

@@ -326,8 +328,8 @@ class BlockDimExprsAsserter {
.out();
auto assert_op = builder_.Build<paddle::dialect::AssertOp>(
all_eq, assert_data, lhs_numel);
const std::string error_msg = "Check [" + op->name() + "_" +
std::to_string(op->id()) +
const std::string error_msg = "Check [" + op->name() +
" id:" + std::to_string(op->id()) +
"] infer symbolic shape failed.";
assert_op->set_attribute(
paddle::dialect::AssertOp::ERROR_INFO_ATTR_NAME,
12 changes: 12 additions & 0 deletions paddle/cinn/hlir/dialect/operator/transforms/cinn_to_pd_util.cc
@@ -190,6 +190,15 @@ ::pir::Operation* ConvertConcatOp(::pir::Operation* op,
return pd_op;
}

::pir::Operation* ConvertGenerateShapeOp(
::pir::Operation* op,
::pir::IrMapping& ir_mapping, // NOLINT
::pir::Builder& builder) { // NOLINT
auto* new_op = op->Clone(ir_mapping, {true, true, true});
builder.Insert(new_op);
return new_op;
}

::pir::Operation* ConvertScaleOp(::pir::Operation* op,
::pir::IrMapping& ir_mapping, // NOLINT
::pir::PatternRewriter& rewriter) { // NOLINT
@@ -404,6 +413,9 @@ REGISTER_TRANSFORM_RULES(concat_op,
cinn::dialect::ConcatOp::name(),
cinn::dialect::details::ConvertConcatOp);

REGISTER_TRANSFORM_RULES(generate_shape_op,
cinn::dialect::GenerateShapeOp::name(),
cinn::dialect::details::ConvertGenerateShapeOp);
REGISTER_TRANSFORM_RULES(scale_op,
cinn::dialect::ScaleOp::name(),
cinn::dialect::details::ConvertScaleOp);
@@ -64,8 +64,16 @@ bool ReplaceOpWithReshapeOp(pir::Operation* op,
}
}
}
auto out_type = paddle::dialect::DenseTensorType::get(
rewriter.ir_context(),
pir::Int64Type::get(rewriter.ir_context()),
::common::make_ddim(
{static_cast<int64_t>(output_dim_expr_attrs.size())}));
auto cinn_generate_shape = rewriter.Build<cinn::dialect::GenerateShapeOp>(
std::vector<pir::Value>{input}, output_dim_expr_attrs, symbol_bindings);
std::vector<pir::Value>{input},
output_dim_expr_attrs,
symbol_bindings,
out_type);
auto pd_reshape = rewriter.Build<paddle::dialect::ReshapeOp>(
op->operand_source(0), cinn_generate_shape.result(0));

@@ -313,19 +313,28 @@ std::optional<pir::Value> GetOutOfRewrittenGenerateShapeOp(
&output_dim_expr_attrs,
&symbol_bindings);
if (!success) return std::nullopt;
auto out_type = [&]() -> pir::Type {
if (shape.type().isa<paddle::dialect::DenseTensorType>()) {
return shape.type();
}
return paddle::dialect::DenseTensorType::get(
rewriter->ir_context(),
pir::Int64Type::get(rewriter->ir_context()),
::common::make_ddim({output_dim_expr_attrs.size()}));
}();
return rewriter
->Build<cinn::dialect::GenerateShapeOp>(
input_tensors, output_dim_expr_attrs, symbol_bindings)
input_tensors, output_dim_expr_attrs, symbol_bindings, out_type)
.out();
}

bool ReplaceShapeOpsToGenerateShape(
pir::OpOperand shape_operand,
pir::PatternRewriter* rewriter,
pir::ShapeConstraintIRAnalysis* shape_analysis) {
if (shape_operand.source()
.defining_op()
->isa<cinn::dialect::GenerateShapeOp>()) {
auto* shape_def_op = shape_operand.source().defining_op();
if (!shape_def_op || shape_def_op->num_operands() == 0) return false;
if (shape_def_op->isa<cinn::dialect::GenerateShapeOp>()) {
return false;
}
auto ShapeOrDataDimExprs4Value =
@@ -379,6 +388,82 @@ class FuseShapeOpsIntoGenerateShapeOpPattern
}
};

class FuseSingleElementShapeOpsIntoGenerateShapeOpPattern
: public pir::RewritePattern {
public:
explicit FuseSingleElementShapeOpsIntoGenerateShapeOpPattern(
pir::IrContext* context)
: pir::RewritePattern(MatchAnyOpTypeTag(),
1 /*benefit*/,
context,
{} /*generated_names*/) {}

bool Match(pir::Operation* op) const override {
auto& shape_analysis =
pir::ShapeAnalysisManager::Instance().Get(op->GetParentProgram());
if (!IsSingleElementShapeOp(op, &shape_analysis)) return false;
if (op->isa<cinn::dialect::GenerateShapeOp>()) return false;

// all user ops' outputs should have no shape-expr data
pir::Value output = op->result(0);
if (output.use_empty()) return false;
for (auto iter = output.use_begin(); iter != output.use_end(); ++iter) {
auto* user = iter->owner();
if (IsSingleElementShapeOp(user, &shape_analysis)) return false;
if (user->isa<cinn::dialect::GenerateShapeOp>()) return false;
}

return true;
}

void Rewrite(pir::Operation* op,
pir::PatternRewriter& rewriter) const override {
auto& shape_analysis =
pir::ShapeAnalysisManager::Instance().Get(op->GetParentProgram());

auto ShapeOrDataDimExprs4Value =
[&shape_analysis](
pir::Value value) -> const symbol::ShapeOrDataDimExprs& {
return shape_analysis.GetShapeOrDataForValue(value);
};
std::optional<pir::Value> opt_generated_shape =
GetOutOfRewrittenGenerateShapeOp(
op->result(0), &rewriter, ShapeOrDataDimExprs4Value);
if (!opt_generated_shape.has_value()) {
LOG(WARNING) << "Create GenerateShapeOp Failed.";
return;
}

rewriter.ReplaceAllUsesWith(op->result(0), opt_generated_shape.value());

if (op->use_empty()) {
rewriter.EraseOp(op);
}
}

private:
bool IsSingleElementShapeOp(
pir::Operation* op,
pir::ShapeConstraintIRAnalysis* shape_analysis) const {
if (op->num_operands() == 0) return false;
if (op->num_results() != 1) return false;

pir::Value output = op->result(0);
const auto& out_shape = shape_analysis->GetShapeOrDataForValue(output);
if (!out_shape.isa<symbol::TensorShapeOrDataDimExprs>()) return false;
if (!out_shape.data().has_value()) return false;

auto dtype =
output.type().dyn_cast<paddle::dialect::DenseTensorType>().dtype();
if (!dtype.isa<pir::Int32Type>() && !dtype.isa<pir::Int64Type>()) {
return false;
}

// Only process ops whose output is a single element
return out_shape.data()->size() == 1;
}
};

class FuseShapeOpsIntoGenerateShapeOpPass : public pir::PatternRewritePass {
public:
FuseShapeOpsIntoGenerateShapeOpPass()
@@ -393,6 +478,7 @@ class FuseShapeOpsIntoGenerateShapeOpPass : public pir::PatternRewritePass {
context);
ps.Add<FuseShapeOpsIntoGenerateShapeOpPattern<paddle::dialect::SliceOp>>(
context);
ps.Add<FuseSingleElementShapeOpsIntoGenerateShapeOpPattern>(context);
return ps;
}

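For context, a rough sketch of how this pass is typically driven; it assumes the accompanying header still exposes a CreateFuseShapeOpsIntoGenerateShapeOpPass() factory in cinn::dialect::ir and elides the program construction, so treat the exact names here as assumptions rather than the verified API:

// Illustrative only: running the pass folds chains of shape ops (now
// including single-element ones) into cinn::dialect::GenerateShapeOp.
pir::IrContext* ctx = pir::IrContext::Instance();
pir::Program program(ctx);
// ... build a program whose reshape/expand shape operands come from shape ops ...
pir::PassManager pass_manager(ctx);
pass_manager.AddPass(
    cinn::dialect::ir::CreateFuseShapeOpsIntoGenerateShapeOpPass());
pass_manager.Run(&program);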
@@ -83,8 +83,10 @@ std::optional<pir::Value> InsertGenerateShapeOpToRunFirst(
&symbol_bindings);
if (success) {
return builder
->Build<cinn::dialect::GenerateShapeOp>(
minimal_inputs, output_dim_expr_attrs, symbol_bindings)
->Build<cinn::dialect::GenerateShapeOp>(minimal_inputs,
output_dim_expr_attrs,
symbol_bindings,
value.type())
.out();
}
return std::nullopt;
@@ -233,17 +233,24 @@ std::tuple<pir::Value, pir::Value, pir::Value> BroadcastableToCondValue(
&rhs_symbol_bindings);
CHECK(success);

auto out_type = paddle::dialect::DenseTensorType::get(
builder.ir_context(),
pir::Int64Type::get(builder.ir_context()),
::common::make_ddim({1}));

auto lhs_value =
builder
.Build<cinn::dialect::GenerateShapeOp>(lhs_minimal_inputs,
lhs_output_dim_expr_attrs,
lhs_symbol_bindings)
lhs_symbol_bindings,
out_type)
.out();
auto rhs_value =
builder
.Build<cinn::dialect::GenerateShapeOp>(rhs_minimal_inputs,
rhs_output_dim_expr_attrs,
rhs_symbol_bindings)
rhs_symbol_bindings,
out_type)
.out();

auto const_one = builder
91 changes: 91 additions & 0 deletions test/ir/pir/cinn/symbolic/test_dyshape_group_norm.py
@@ -0,0 +1,91 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import unittest
from os.path import dirname

import numpy as np

import paddle
from paddle import nn
from paddle.static import InputSpec

sys.path.append(dirname(dirname(__file__)))

import utils


class GroupNorm(nn.Layer):
def __init__(self):
super().__init__()
self.hidden_size = 768
self.dtype = "float32"
self.weight = paddle.randn([128], dtype=self.dtype)
self.weight.stop_gradient = False
self.bias = paddle.randn([128], dtype=self.dtype)
self.bias.stop_gradient = False

self.data_format = "NHWC"

def forward(self, x):
return paddle.nn.functional.group_norm(
x,
num_groups=32,
epsilon=1e-6,
weight=self.weight,
bias=self.bias,
data_format=self.data_format,
)


class TestGroupNorm(unittest.TestCase):
def setUp(self):
paddle.seed(2024)
self.shape = [1, 128, 256, 128]
self.dtype = "float32"
self.data_format = "NHWC"
self.prepare_data()

def prepare_data(self):
self.x = paddle.randn(self.shape, dtype=self.dtype)
self.x.stop_gradient = False

def check_jit_kernel_info(self, static_fn):
utils.check_jit_kernel_number(static_fn, 2)
utils.check_jit_kernel_structure(static_fn, {utils.JIT_KERNEL_NAME: 2})

def eval(self, use_cinn):
paddle.seed(2024)
net = GroupNorm()
input_spec = [
InputSpec(shape=[None, None, None, 128], dtype='float32'),
]
net = utils.apply_to_static(net, use_cinn, input_spec)
net.eval()
out = net(self.x)
if use_cinn:
self.check_jit_kernel_info(net.forward)
return out

def test_eval(self):
cinn_out = self.eval(use_cinn=True)
dy_out = self.eval(use_cinn=False)
np.testing.assert_allclose(
cinn_out.numpy(), dy_out.numpy(), atol=1e-6, rtol=1e-6
)


if __name__ == '__main__':
unittest.main()