[Paddle Inference] Add add_group_norm_silu kernel and group_norm related pattern #64199

Merged · 3 commits · Jun 4, 2024
1 change: 1 addition & 0 deletions paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -613,6 +613,7 @@ const std::vector<std::string> kPirGpuPasses{
"fused_weight_only_linear_pass",
"matmul_add_act_fuse_pass",
"fc_elementwise_layernorm_fuse_pass",
"add_norm_fuse_pass",
"matmul_scale_fuse_pass",
"matmul_transpose_fuse_pass",
"transpose_flatten_concat_fuse_pass",
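Since the hunk above appends "add_norm_fuse_pass" to kPirGpuPasses, the fusion runs by default in the PIR GPU pass pipeline. A minimal usage sketch, assuming the standard paddle_infer::Config API (switch names such as EnableNewIR may differ across Paddle versions) and a hypothetical model directory:

#include "paddle_inference_api.h"  // public Paddle Inference header

paddle_infer::Config config("./model_dir");  // hypothetical model path
config.EnableUseGpu(/*memory_pool_init_size_mb=*/256, /*device_id=*/0);
config.EnableNewIR(true);  // enable the PIR pipeline so kPirGpuPasses, including add_norm_fuse_pass, applies
auto predictor = paddle_infer::CreatePredictor(config);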
@@ -38,6 +38,45 @@ void RewriteByInfermeta(pir::Operation* op, common::DataLayout new_layout) {
}
}

template <>
std::vector<pir::Value> RelevantInputsImpl<AddGroupNormSiluOp>(
pir::Operation* op) {
auto concrete_op = op->dyn_cast<AddGroupNormSiluOp>();
return {concrete_op.x(), concrete_op.residual()};
}

template <>
std::vector<pir::Value> RelevantOutputsImpl<AddGroupNormSiluOp>(
pir::Operation* op) {
auto concrete_op = op->dyn_cast<AddGroupNormSiluOp>();
return {concrete_op.y(), concrete_op.residual_out()};
}

template <>
common::DataLayout PreferLayoutImpl<AddGroupNormSiluOp>(pir::Operation* op) {
// Note(bukejiyu): add_group_norm_silu only supports NHWC layout now.
return common::DataLayout::NHWC;
}

template <>
void RewriteByLayoutImpl<AddGroupNormSiluOp>(pir::Operation* op,
common::DataLayout new_layout) {
op->set_attribute(
"data_format",
pir::StrAttribute::get(pir::IrContext::Instance(),
common::DataLayoutToString(new_layout)));

std::vector<pir::Type> new_outputs = AddGroupNormSiluOp::InferMeta(
op->operands_source(), const_cast<pir::AttributeMap*>(&op->attributes()));
Contributor review comment: Could we avoid the const_cast here and construct a temporary attribute map instead? Under PIR, Attributes are shared, so building the map is very cheap. (A sketch of this suggestion follows the function below.)

for (size_t i = 0; i < new_outputs.size(); ++i) {
op->result(i).set_type(new_outputs[i]);
}

for (auto value : RelevantOutputsImpl<AddGroupNormSiluOp>(op)) {
SetNewLayoutForValue(value, new_layout);
}
}
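A minimal sketch of the reviewer's suggestion above, assuming AddGroupNormSiluOp::InferMeta accepts a pir::AttributeMap* exactly as in the call shown; illustrative only, not the code merged in this PR:

// The op itself is still updated via set_attribute as before; only the
// InferMeta call switches to a local, mutable copy of the attribute map.
// Attributes are shared handles in PIR, so the copy is cheap.
pir::AttributeMap attrs = op->attributes();
std::vector<pir::Type> new_outputs =
    AddGroupNormSiluOp::InferMeta(op->operands_source(), &attrs);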

template <>
common::DataLayout PreferLayoutImpl<Conv2dOp>(pir::Operation* op) {
auto data_format_attr = op->attribute<pir::StrAttribute>("data_format");
@@ -78,6 +117,14 @@ common::DataLayout PreferLayoutImpl<FusedConv2dAddActOp>(pir::Operation* op) {
auto original_layout =
common::StringToDataLayout(data_format_attr.AsString());

if (op->HasAttribute(kForceBackendAttr) &&
op->attributes()
.at(kForceBackendAttr)
.dyn_cast<pir::StrAttribute>()
.AsString() == "gpu") {
return common::DataLayout::NHWC;
}

auto concrete_op = op->dyn_cast<FusedConv2dAddActOp>();
if (auto in = concrete_op.input()) {
if (auto in_type = in.type()) {
@@ -115,6 +115,12 @@ OVERLOAD_REWRITE_BY_LAYOUT(GroupNormOp);
OVERLOAD_RELEVANT_INPUTS(GroupNormOp);
OVERLOAD_RELEVANT_OUTPUTS(GroupNormOp);

class AddGroupNormSiluOp;
OVERLOAD_REWRITE_BY_LAYOUT(AddGroupNormSiluOp);
OVERLOAD_PREFER_LAYOUT(AddGroupNormSiluOp);
OVERLOAD_RELEVANT_INPUTS(AddGroupNormSiluOp);
OVERLOAD_RELEVANT_OUTPUTS(AddGroupNormSiluOp);

class ReshapeOp;
OVERLOAD_RELEVANT_INPUTS(ReshapeOp);
OVERLOAD_RELEVANT_OUTPUTS(ReshapeOp);
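For reference, a guess at what the OVERLOAD_* declarations expand to for AddGroupNormSiluOp, inferred from the specializations defined in the implementation file earlier in this PR (the macro bodies themselves are outside this diff):

template <>
std::vector<pir::Value> RelevantInputsImpl<AddGroupNormSiluOp>(pir::Operation* op);
template <>
std::vector<pir::Value> RelevantOutputsImpl<AddGroupNormSiluOp>(pir::Operation* op);
template <>
common::DataLayout PreferLayoutImpl<AddGroupNormSiluOp>(pir::Operation* op);
template <>
void RewriteByLayoutImpl<AddGroupNormSiluOp>(pir::Operation* op,
                                             common::DataLayout new_layout);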
2 changes: 1 addition & 1 deletion paddle/fluid/pir/drr/src/rewrite_pattern.cc
@@ -391,7 +391,7 @@ bool DrrRewritePattern::MatchFromOutputToInput(
ir_input_values[i].use_count()) {
matched = false;
VLOG(8) << drr_node->name() << " Match failed: consumers of drr intput["
<< i << "] { " << drr_node->outputs().size()
<< i << "] { " << drr_input_tensors[i]->consumers().size()
<< " } != consumers of pir intput[" << i << "] { "
<< ir_input_values[i].use_count() << " }.";
break;
228 changes: 207 additions & 21 deletions paddle/fluid/pir/transforms/gpu/add_norm_fuse_pass.cc
@@ -141,11 +141,13 @@ class RmsNormFusePattern : public paddle::drr::DrrPatternBase {
class AddRmsNormFusePattern : public paddle::drr::DrrPatternBase {
private:
const bool extra_add_;
const bool trans_extra_add_;

public:
explicit AddRmsNormFusePattern(bool extra_add) : extra_add_(extra_add) {}
AddRmsNormFusePattern(bool extra_add, bool trans_extra_add)
: extra_add_(extra_add), trans_extra_add_{trans_extra_add} {}

uint32_t benefit() const override { return extra_add_ ? 2 : 1; }
uint32_t benefit() const override { return extra_add_ ? 4 : 3; }

std::string name() const override { return "AddRmsNormFusePattern"; }

@@ -176,7 +178,9 @@ class AddRmsNormFusePattern : public paddle::drr::DrrPatternBase {
if (extra_add_) {
const auto &add1 = pat.Op(paddle::dialect::AddOp::name());
pat.Tensor("add_out1") =
add1(pat.Tensor("add_out"), pat.Tensor("any_tensor"));
trans_extra_add_
? add1(pat.Tensor("any_tensor"), pat.Tensor("add_out"))
: add1(pat.Tensor("add_out"), pat.Tensor("any_tensor"));
}
paddle::drr::ResultPattern res = pat.ResultPattern();
const auto &res_rms_norm =
@@ -207,11 +211,13 @@ class AddRmsNormFusePattern : public paddle::drr::DrrPatternBase {
class AddLayerNormFusePattern : public paddle::drr::DrrPatternBase {
private:
const bool extra_add_;
const bool trans_extra_add_;

public:
explicit AddLayerNormFusePattern(bool extra_add) : extra_add_(extra_add) {}
AddLayerNormFusePattern(bool extra_add, bool trans_extra_add)
: extra_add_(extra_add), trans_extra_add_{trans_extra_add} {}

uint32_t benefit() const override { return extra_add_ ? 2 : 1; }
uint32_t benefit() const override { return extra_add_ ? 4 : 3; }
std::string name() const override { return "AddLayerNormFusePattern"; }

void operator()(paddle::drr::DrrPatternContext *ctx) const override {
@@ -231,22 +237,20 @@ class AddLayerNormFusePattern : public paddle::drr::DrrPatternBase {
if (extra_add_) {
const auto &add1 = pat.Op(paddle::dialect::AddOp::name());
pat.Tensor("add_out1") =
add1(pat.Tensor("add_out"), pat.Tensor("any_tensor"));
trans_extra_add_
? add1(pat.Tensor("any_tensor"), pat.Tensor("add_out"))
: add1(pat.Tensor("add_out"), pat.Tensor("any_tensor"));
}

paddle::drr::ResultPattern res = pat.ResultPattern();
const auto &cast_op_dtype = res.ComputeAttr(
[](const paddle::drr::MatchContext &match_ctx) -> phi::DataType {
auto x_dtype = pir::GetDataTypeFromValue(match_ctx.Tensor("x"));
return paddle::dialect::TransToPhiDataType(x_dtype);
return phi::DataType::FLOAT32;
});
const auto &cast_op_1 =
const auto cast_1_op =
res.Op(paddle::dialect::CastOp::name(), {{"dtype", cast_op_dtype}});
res.Tensor("casted_bias") = cast_op_1(res.Tensor("bias"));
const auto &cast_op_2 =
const auto cast_2_op =
res.Op(paddle::dialect::CastOp::name(), {{"dtype", cast_op_dtype}});
res.Tensor("casted_w") = cast_op_2(res.Tensor("w"));

const auto &fuse_layer_norm =
res.Op(paddle::dialect::FusedBiasResidualLayernormOp::name(),
{{"epsilon", pat.Attr("epsilon")},
@@ -256,14 +260,15 @@ class AddLayerNormFusePattern : public paddle::drr::DrrPatternBase {
{"quant_round_type", res.Int32Attr(0)},
{"quant_max_bound", res.Float32Attr(0.0)},
{"quant_min_bound", res.Float32Attr(0.0)}});

res.Tensor("w_cast") = cast_1_op(res.Tensor("w"));
res.Tensor("bias_cast") = cast_1_op(res.Tensor("bias"));
fuse_layer_norm(
{
&res.Tensor("x"),
&res.Tensor("casted_bias"),
&res.Tensor("residual"),
&res.Tensor("casted_w"),
&res.InputNoneTensor(),
&res.Tensor("residual"),
&res.Tensor("w_cast"),
&res.Tensor("bias_cast"),
},
{&res.Tensor("layer_norm_out"),
&res.Tensor("add_out"),
@@ -272,6 +277,163 @@ class AddLayerNormFusePattern : public paddle::drr::DrrPatternBase {
}
};

class AddGroupNormFusePattern : public paddle::drr::DrrPatternBase {
private:
const bool extra_add_;
const bool trans_extra_add_;

public:
AddGroupNormFusePattern(bool extra_add, bool trans_extra_add)
: extra_add_(extra_add), trans_extra_add_{trans_extra_add} {}

uint32_t benefit() const override { return extra_add_ ? 4 : 3; }
std::string name() const override { return "AddGroupNormFusePattern"; }

void operator()(paddle::drr::DrrPatternContext *ctx) const override {
paddle::drr::SourcePattern pat = ctx->SourcePattern();
const auto &add = pat.Op(paddle::dialect::AddOp::name());
const auto &group_norm = pat.Op(paddle::dialect::GroupNormOp::name(),
{{"epsilon", pat.Attr("epsilon")},
{"groups", pat.Attr("groups")},
{"data_format", pat.Attr("data_format")}});
pat.Tensor("add_out") = add(pat.Tensor("x"), pat.Tensor("residual"));
group_norm(
{&pat.Tensor("add_out"), &pat.Tensor("scale"), &pat.Tensor("bias")},
{&pat.Tensor("group_out"),
&pat.Tensor("mean_out_0"),
&pat.Tensor("variance_out_0")});
// TODO(bukejiyu): once DRR supports matching a placeholder op,
// the following extra_add_ branch can be deleted.
if (extra_add_) {
const auto &add1 = pat.Op(paddle::dialect::AddOp::name());
pat.Tensor("add_out1") =
trans_extra_add_
? add1(pat.Tensor("any_tensor"), pat.Tensor("add_out"))
: add1(pat.Tensor("add_out"), pat.Tensor("any_tensor"));
}
pat.AddConstraint([this](const paddle::drr::MatchContext &match_ctx) {
auto x_dtype = pir::GetDataTypeFromValue(match_ctx.Tensor("x"));
if (!x_dtype.isa<pir::Float16Type>() &&
!x_dtype.isa<pir::BFloat16Type>()) {
return false;
}
return true;
});
paddle::drr::ResultPattern res = pat.ResultPattern();
const auto &add_group_norm_silu_op =
res.Op(paddle::dialect::AddGroupNormSiluOp::name(),
{{"epsilon", pat.Attr("epsilon")},
{"groups", pat.Attr("groups")},
{"data_format", pat.Attr("data_format")},
{"activation", res.StrAttr("")}});

add_group_norm_silu_op({&res.Tensor("x"),
&res.Tensor("residual"),
&res.Tensor("scale"),
&res.Tensor("bias")},
{&res.Tensor("group_out"),
&res.Tensor("add_out"),
&res.Tensor("mean_out"),
&res.Tensor("variance_out")});
}
};
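Reading of the fused op this pattern emits, as implied by its inputs and outputs (a hedged interpretation of the pattern above, not an authoritative spec of AddGroupNormSiluOp):

// add_out   = x + residual                      (mapped to the op's residual_out)
// group_out = act(group_norm(add_out, scale, bias; groups, epsilon))
//             act is identity here ("activation" == "") and becomes "silu"
//             once AddGroupNormWithActPattern below folds a trailing silu in.
// mean_out / variance_out are the usual group_norm statistics of add_out.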

class AddGroupNormWithActPattern : public paddle::drr::DrrPatternBase {
public:
uint32_t benefit() const override { return 2; }
std::string name() const override { return "AddGroupNormWithActPattern"; }

void operator()(paddle::drr::DrrPatternContext *ctx) const override {
paddle::drr::SourcePattern pat = ctx->SourcePattern();
const auto &add_group_norm_silu_op =
pat.Op(paddle::dialect::AddGroupNormSiluOp::name(),
{{"epsilon", pat.Attr("epsilon")},
{"groups", pat.Attr("groups")},
{"data_format", pat.Attr("data_format")},
{"activation", pat.Attr("activation")}});
const auto &silu = pat.Op(paddle::dialect::SiluOp::name());
add_group_norm_silu_op({&pat.Tensor("x"),
&pat.Tensor("residual"),
&pat.Tensor("scale"),
&pat.Tensor("bias")},
{&pat.Tensor("group_out"),
&pat.Tensor("add_out"),
&pat.Tensor("mean_out_0"),
&pat.Tensor("variance_out_0")});
pat.Tensor("silu_out") = silu(pat.Tensor("group_out"));
pat.AddConstraint([this](const paddle::drr::MatchContext &match_ctx) {
auto x_dtype = pir::GetDataTypeFromValue(match_ctx.Tensor("x"));
if (!x_dtype.isa<pir::Float16Type>() &&
!x_dtype.isa<pir::BFloat16Type>()) {
return false;
}
auto activation = match_ctx.Attr<std::string>("activation");
if (activation != "") {
return false;
}
return true;
});
paddle::drr::ResultPattern res = pat.ResultPattern();
const auto &res_add_group_norm_silu_op =
res.Op(paddle::dialect::AddGroupNormSiluOp::name(),
{{"epsilon", pat.Attr("epsilon")},
{"groups", pat.Attr("groups")},
{"data_format", pat.Attr("data_format")},
{"activation", res.StrAttr("silu")}});
res_add_group_norm_silu_op({&res.Tensor("x"),
&res.Tensor("residual"),
&res.Tensor("scale"),
&res.Tensor("bias")},
{&res.Tensor("silu_out"),
&res.Tensor("add_out"),
&res.Tensor("mean_out"),
&res.Tensor("variance_out")});
}
};

class GroupNormWithActPattern : public paddle::drr::DrrPatternBase {
public:
uint32_t benefit() const override { return 1; }
std::string name() const override { return "GroupNormWithActPattern"; }

void operator()(paddle::drr::DrrPatternContext *ctx) const override {
paddle::drr::SourcePattern pat = ctx->SourcePattern();
const auto &group_norm = pat.Op(paddle::dialect::GroupNormOp::name(),
{{"epsilon", pat.Attr("epsilon")},
{"groups", pat.Attr("groups")},
{"data_format", pat.Attr("data_format")}});
const auto &silu = pat.Op(paddle::dialect::SiluOp::name());
group_norm({&pat.Tensor("x"), &pat.Tensor("scale"), &pat.Tensor("bias")},
{&pat.Tensor("group_out"),
&pat.Tensor("mean_out_0"),
&pat.Tensor("variance_out_0")});
pat.Tensor("silu_out") = silu(pat.Tensor("group_out"));
pat.AddConstraint([this](const paddle::drr::MatchContext &match_ctx) {
auto x_dtype = pir::GetDataTypeFromValue(match_ctx.Tensor("x"));
if (!x_dtype.isa<pir::Float16Type>() &&
!x_dtype.isa<pir::BFloat16Type>()) {
return false;
}
return true;
});
paddle::drr::ResultPattern res = pat.ResultPattern();
const auto &add_group_norm_silu_op =
res.Op(paddle::dialect::AddGroupNormSiluOp::name(),
{{"epsilon", pat.Attr("epsilon")},
{"groups", pat.Attr("groups")},
{"data_format", pat.Attr("data_format")},
{"activation", res.StrAttr("silu")}});
add_group_norm_silu_op({&res.Tensor("x"),
&res.InputNoneTensor(),
&res.Tensor("scale"),
&res.Tensor("bias")},
{&res.Tensor("silu_out"),
&res.OutputNoneTensor(),
&res.Tensor("mean_out"),
&res.Tensor("variance_out")});
}
};

class AddNormFusePass : public pir::PatternRewritePass {
public:
AddNormFusePass() : pir::PatternRewritePass("add_norm_fuse_pass", 2) {}
@@ -290,13 +452,37 @@ class AddNormFusePass : public pir::PatternRewritePass {
// x--------
// add-rms_norm ---> rms_norm
// residual-
ps.Add(paddle::drr::Create<AddRmsNormFusePattern>(context, !extra_add));
ps.Add(paddle::drr::Create<AddRmsNormFusePattern>(context, extra_add));
ps.Add(
paddle::drr::Create<AddRmsNormFusePattern>(context, !extra_add, false));
ps.Add(
paddle::drr::Create<AddRmsNormFusePattern>(context, extra_add, true));
ps.Add(
paddle::drr::Create<AddRmsNormFusePattern>(context, extra_add, false));

// x--------
// add-layer_norm ----> fused_bias_residual_layernorm
// residual-
ps.Add(paddle::drr::Create<AddLayerNormFusePattern>(context, !extra_add));
ps.Add(paddle::drr::Create<AddLayerNormFusePattern>(context, extra_add));
ps.Add(paddle::drr::Create<AddLayerNormFusePattern>(
context, !extra_add, false));
ps.Add(
paddle::drr::Create<AddLayerNormFusePattern>(context, extra_add, true));
ps.Add(paddle::drr::Create<AddLayerNormFusePattern>(
context, extra_add, false));

// x--------
// add-group_norm ----> add_group_norm_silu
// residual-
ps.Add(paddle::drr::Create<AddGroupNormFusePattern>(
context, !extra_add, true));
ps.Add(
paddle::drr::Create<AddGroupNormFusePattern>(context, extra_add, true));
ps.Add(paddle::drr::Create<AddGroupNormFusePattern>(
context, extra_add, false));

// add_group_norm_silu-silu --->add_group_norm_silu
ps.Add(paddle::drr::Create<AddGroupNormWithActPattern>(context));
// group-silu->add_group_norm_silu
ps.Add(paddle::drr::Create<GroupNormWithActPattern>(context));
return ps;
}
};
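A minimal sketch of running this pass on its own over a PIR program, assuming the usual pir::PassManager API and a CreateAddNormFusePass() factory declared in the corresponding header (the registration boilerplate is outside this diff):

pir::IrContext* ctx = pir::IrContext::Instance();
pir::PassManager pm(ctx);
pm.AddPass(pir::CreateAddNormFusePass());  // assumed factory function name
pm.Run(&program);                          // `program` is a pir::Program built elsewhere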