Add an elementwise + activation fusion pass. (#36541)
* Add elementwise add and activation fuse pass

* Fix copy elision

* More flexible pattern detector

* More flexible fusion pass

* Update lists for pass

* Add support for Pow operator

* Add support for more activation types

* Style

* Rename fusion pass

* First version of tests

* Dirty version of pass

* Polished version

* Update pbtxt

* Style

* Update names

* Style

* Use PADDLE_ENFORCE_EQ

* Save error message to variable

* Workaround for error checks

* CR

* Static style check

* Add missing 'activation_scale' attribute

* Add relu6 and sigmoid activations

* Style

* Fix fuse list formatting

* Sync filenames for fuse pass files

* Fix cmake after move

* Fix registration

* Fix pass name in tests

* Add missing activations to checker

* WIP

* Working mul op

* Working sub

* Working Add

* Remove pten includes

* Remove some forward declarations

* Remove Includes

* Fixes

* Remove default kernels

* Add check if post_ops attributes are available

* Style

* Code adjustment

* Register default kernels

* We are in 2022, not 2021...

Co-authored-by: jakpiase <jakpia21@gmail.com>
Co-authored-by: Sylwester Fraczek <sylwester.fraczek@intel.com>

* Fast review fixes

Co-authored-by: jakpiase <jakpia21@gmail.com>
Co-authored-by: Sylwester Fraczek <sylwester.fraczek@intel.com>

* Review Fix

* Rename one_dnn -> onednn

* Style after review

* Fast and dirty fix for quantization

* Update tests

* Style

* Fix mkldnn_quantizer config

* Add Joanna's suggestion.

* Check if operator is explicitly disabled on oneDNN

* Try to use unregistered attributes

* Style

* Test new framework

* Fix I

* Fix II

* Update test

* Style

Co-authored-by: jakpiase <jakpia21@gmail.com>
Co-authored-by: Sylwester Fraczek <sylwester.fraczek@intel.com>
3 people committed Mar 14, 2022
1 parent 1f7b251 commit 3f21916
Showing 10 changed files with 702 additions and 10 deletions.
1 change: 1 addition & 0 deletions paddle/fluid/framework/ir/CMakeLists.txt
@@ -126,6 +126,7 @@ if(WITH_MKLDNN)
pass_library(interpolate_mkldnn_pass inference DIR mkldnn)
pass_library(softplus_activation_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(fc_act_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(elt_act_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(cpu_quantize_placement_pass base DIR mkldnn)
pass_library(cpu_quantize_pass inference DIR mkldnn)
pass_library(cpu_quantize_squash_pass inference DIR mkldnn)
30 changes: 30 additions & 0 deletions paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -918,6 +918,36 @@ PDNode *patterns::ConvActivation::operator()(
return activation_out_var;
}

PDNode *patterns::ElementwiseActivation::operator()(
paddle::framework::ir::PDNode *elementwise_a,
const std::string &elementwise_type, const std::string &activation_type) {
// Create Operators
elementwise_a->assert_is_op_input(elementwise_type, "X");
auto *elementwise_op =
pattern->NewNode(elementwise_repr())->assert_is_op(elementwise_type);
auto *activation_op =
pattern->NewNode(activation_repr())->assert_is_op(activation_type);
// Create variables
auto *elementwise_b = pattern->NewNode(elementwise_b_repr())
->AsInput()
->assert_is_op_input(elementwise_type, "Y");
// Intermediate variable; removed from the IR after the fusion.
auto *elementwise_out_var =
pattern->NewNode(elementwise_out_repr())
->AsIntermediate()
->assert_is_only_output_of_op(elementwise_type)
->assert_is_op_input(activation_type);
// output
auto *activation_out_var = pattern->NewNode(activation_out_repr())
->AsOutput()
->assert_is_op_output(activation_type);

elementwise_op->LinksFrom({elementwise_a, elementwise_b})
.LinksTo({elementwise_out_var});
activation_op->LinksFrom({elementwise_out_var}).LinksTo({activation_out_var});
return activation_out_var;
}

PDNode *patterns::SeqConvEltAddRelu::operator()(
paddle::framework::ir::PDNode *seqconv_input) {
// Create Operators
22 changes: 22 additions & 0 deletions paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -487,6 +487,28 @@ struct ConvActivation : public PatternBase {
PATTERN_DECL_NODE(activation_out);
};

// Elementwise with Activation
// op: elementwise + activation
// named nodes:
// elementwise_a, elementwise_b,
// elementwise_out, elementwise,
// activation_out, activation
struct ElementwiseActivation : public PatternBase {
ElementwiseActivation(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "elementwise_add_activation") {}

PDNode* operator()(PDNode* elementwise_a, const std::string& elementwise_type,
const std::string& activation_type);

// declare operator node's name
PATTERN_DECL_NODE(elementwise);
PATTERN_DECL_NODE(activation);
// declare variable node's name
PATTERN_DECL_NODE(elementwise_b);
PATTERN_DECL_NODE(elementwise_out);
PATTERN_DECL_NODE(activation_out);
};

// SEQCONV with Elementwise_Add ReLU
// op: seqconv + elementwise_add + relu
// named nodes:
145 changes: 145 additions & 0 deletions paddle/fluid/framework/ir/mkldnn/elt_act_mkldnn_fuse_pass.cc
@@ -0,0 +1,145 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/mkldnn/elt_act_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/pretty_log.h"

namespace paddle {
namespace framework {
namespace ir {

using string::PrettyLogDetail;

void ElementwiseActivationOneDNNPass::ApplyImpl(Graph *graph) const {
std::vector<std::string> act_types = {
"relu", "tanh", "leaky_relu", "swish", "hardswish", "sqrt",
"abs", "clip", "gelu", "relu6", "sigmoid"};
std::vector<std::string> elt_types = {"elementwise_add", "elementwise_sub",
"elementwise_mul"};

for (const auto &elt_type : elt_types)
for (const auto &act_type : act_types) {
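// Map the activation's native attribute names onto the fused operator's
// generic activation_alpha / activation_beta attributes.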
std::unordered_map<std::string, std::string> attr_map;

if (act_type == "swish")
attr_map.emplace("beta", "activation_alpha");
else if (act_type == "relu6")
attr_map.emplace("threshold", "activation_alpha");
else if (act_type == "clip") {
attr_map.emplace("min", "activation_alpha");
attr_map.emplace("max", "activation_beta");
} else {
attr_map.emplace("alpha", "activation_alpha");
attr_map.emplace("beta", "activation_beta");
}
FuseElementwiseAct(graph, elt_type, act_type, attr_map);
}
}

void ElementwiseActivationOneDNNPass::FuseElementwiseAct(
Graph *graph, const std::string &elt_type, const std::string &act_type,
const std::unordered_map<std::string, std::string> &attr_map) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init("elementwise_act", graph);

GraphPatternDetector gpd;
auto *elementwise_input = gpd.mutable_pattern()
->NewNode(elt_type + "_act/elementwise_input")
->AsInput()
->assert_is_op_input(elt_type, "X");
patterns::ElementwiseActivation elementwise_act_pattern(gpd.mutable_pattern(),
elt_type + "_act");
elementwise_act_pattern(elementwise_input, elt_type, act_type);

int found_elementwise_activation_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) {
VLOG(4) << "Fuse " << elt_type << " with activation op.";
// Elementwise output
GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out,
elementwise_act_pattern);
// ACT output
GET_IR_NODE_FROM_SUBGRAPH(activation_out, activation_out,
elementwise_act_pattern);
// ops
GET_IR_NODE_FROM_SUBGRAPH(elementwise, elementwise,
elementwise_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(activation, activation, elementwise_act_pattern);

auto *elementwise_op = elementwise->Op();

if (elementwise_op->HasAttr("use_mkldnn")) {
const std::string wo_elt_type =
"The " + elt_type; // Workaround for PP error message checking.
PADDLE_ENFORCE_EQ(
BOOST_GET_CONST(bool, elementwise_op->GetAttr("use_mkldnn")), true,
platform::errors::PreconditionNotMet(
wo_elt_type + "+Act fusion may happen only when oneDNN library "
"is used."));
}

auto *activation_op = activation->Op();
for (const auto &attr : attr_map) {
if (activation_op->HasAttr(attr.first)) {
elementwise_op->SetAttr(attr.second,
activation_op->GetAttr(attr.first));
}
}

if (act_type == "gelu" && activation_op->HasAttr("approximate") &&
BOOST_GET_CONST(bool, activation_op->GetAttr("approximate")))
elementwise_op->SetAttr("activation_type", std::string("gelu_tanh"));
else
elementwise_op->SetAttr("activation_type", act_type);

elementwise_op->SetOutput("Out", {activation_out->Name()});

IR_OP_VAR_LINK(elementwise, activation_out);
GraphSafeRemoveNodes(g, {activation, elementwise_out});
found_elementwise_activation_count++;
};

gpd(graph, handler);
AddStatis(found_elementwise_activation_count);
PrettyLogDetail("--- fused %d %s with %s activation",
found_elementwise_activation_count, elt_type, act_type);
}

} // namespace ir
} // namespace framework
} // namespace paddle

REGISTER_PASS(elt_act_mkldnn_fuse_pass,
paddle::framework::ir::ElementwiseActivationOneDNNPass);
REGISTER_PASS_CAPABILITY(elt_act_mkldnn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("elementwise_add", 1)
.LE("elementwise_sub", 1)
.LE("elementwise_mul", 1)
.LE("relu", 0)
.LE("tanh", 0)
.LE("leaky_relu", 1)
.LE("swish", 0)
.LE("hard_swish", 0)
.LE("sqrt", 0)
.LE("abs", 0)
.LE("clip", 1)
.LE("gelu", 0)
.LE("relu6", 0)
.LE("sigmoid", 0));
44 changes: 44 additions & 0 deletions paddle/fluid/framework/ir/mkldnn/elt_act_mkldnn_fuse_pass.h
@@ -0,0 +1,44 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>

#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"

namespace paddle {
namespace framework {
namespace ir {

/*
* \brief Fuse an elementwise operator and its following activation into a
* single oneDNN elementwise operator with an activation post-op.
*/
class ElementwiseActivationOneDNNPass : public FusePassBase {
public:
virtual ~ElementwiseActivationOneDNNPass() {}

protected:
void ApplyImpl(Graph *graph) const override;

void FuseElementwiseAct(
Graph *graph, const std::string &elt_type, const std::string &act_type,
const std::unordered_map<std::string, std::string> &attr_map) const;
};

} // namespace ir
} // namespace framework
} // namespace paddle
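
For a quick check outside the full pipeline, the pass can also be fetched from the registry and applied to a graph directly — a sketch along the lines of the existing fuse-pass unit tests; the graph is assumed to be built from a small test program, and RunEltActFusePass is an illustrative helper:

#include <memory>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"

// Sketch: retrieve the registered pass and rewrite a test graph in place.
void RunEltActFusePass(std::unique_ptr<paddle::framework::ir::Graph>* graph) {
  auto pass = paddle::framework::ir::PassRegistry::Instance().Get(
      "elt_act_mkldnn_fuse_pass");
  graph->reset(pass->Apply(graph->release()));
}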
1 change: 1 addition & 0 deletions paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -262,6 +262,7 @@ void CpuPassStrategy::EnableMKLDNN() {
// "fc_act_mkldnn_fuse_pass",
"batch_norm_act_fuse_pass", //
"softplus_activation_mkldnn_fuse_pass", //
"elt_act_mkldnn_fuse_pass", //
// TODO(intel): Please fix the bug on windows.
// https://github.com/PaddlePaddle/Paddle/issues/29710
// "mkldnn_inplace_pass", // This pass should be activated after
45 changes: 42 additions & 3 deletions paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h
@@ -32,6 +32,45 @@ using dnnl::stream;

template <typename T, dnnl::algorithm BINARY_OP>
class EltwiseMKLDNNKernel : public framework::OpKernel<T> {
private:
dnnl::post_ops get_post_ops(const framework::ExecutionContext& ctx) const {
dnnl::post_ops post_operations;
if (ctx.HasAttr("activation_type")) {
const float scale = ctx.HasAttr("activation_scale")
? ctx.Attr<float>("activation_scale")
: 1.0f;
const float alpha = ctx.HasAttr("activation_alpha")
? ctx.Attr<float>("activation_alpha")
: 0.0f;
const float beta = ctx.HasAttr("activation_beta")
? ctx.Attr<float>("activation_beta")
: 0.0f;

static std::unordered_map<std::string, dnnl::algorithm> algo_map = {
{"relu", dnnl::algorithm::eltwise_relu},
{"tanh", dnnl::algorithm::eltwise_tanh},
{"leaky_relu", dnnl::algorithm::eltwise_relu},
{"swish", dnnl::algorithm::eltwise_swish},
{"hardswish", dnnl::algorithm::eltwise_hardswish},
{"sqrt", dnnl::algorithm::eltwise_sqrt},
{"abs", dnnl::algorithm::eltwise_abs},
{"clip", dnnl::algorithm::eltwise_clip},
{"gelu", dnnl::algorithm::eltwise_gelu_erf},
{"gelu_tanh", dnnl::algorithm::eltwise_gelu_tanh},
{"relu6", dnnl::algorithm::eltwise_bounded_relu},
{"sigmoid", dnnl::algorithm::eltwise_logistic}};

const auto& activation_type =
algo_map.find(ctx.Attr<std::string>("activation_type"));

if (activation_type != algo_map.end()) {
post_operations.append_eltwise(scale, activation_type->second, alpha,
beta);
}
}
return post_operations;
}

public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto& dev_ctx =
@@ -47,9 +86,9 @@ class EltwiseMKLDNNKernel : public framework::OpKernel<T> {
float scale_o = ctx.Attr<float>("Scale_out");
int axis = ctx.Attr<int>("axis");

- platform::BinaryMKLDNNHandler<T> handler(BINARY_OP, axis, mkldnn_engine,
-                                          ctx.GetPlace(), x, y, z, scale_x,
-                                          scale_y, scale_o);
+ platform::BinaryMKLDNNHandler<T> handler(
+     BINARY_OP, axis, mkldnn_engine, ctx.GetPlace(), x, y, z, scale_x,
+     scale_y, scale_o, get_post_ops(ctx));

const auto src_x_memory = handler.AcquireSrcMemory(x);
const auto src_y_memory = handler.AcquireSecondSrcMemory(y);
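
Note that get_post_ops above translates the fused activation into a oneDNN eltwise post-op, so the binary primitive applies the activation in the same pass over memory instead of launching a second kernel. A standalone sketch of that mechanism (oneDNN 2.x API; the ReLU and its parameters are illustrative):

#include "dnnl.hpp"

// Sketch: append a ReLU eltwise post-op to a primitive attribute, as
// get_post_ops() does for the fused activation. scale/alpha/beta are the
// illustrative defaults used when no attribute overrides them.
dnnl::primitive_attr CreateFusedAttr() {
  dnnl::post_ops ops;
  ops.append_eltwise(/*scale=*/1.0f, dnnl::algorithm::eltwise_relu,
                     /*alpha=*/0.0f, /*beta=*/0.0f);
  dnnl::primitive_attr attr;
  attr.set_post_ops(ops);
  return attr;  // later passed to dnnl::binary::primitive_desc
}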
14 changes: 7 additions & 7 deletions paddle/fluid/platform/mkldnn_reuse.h
@@ -618,7 +618,7 @@ class BinaryMKLDNNHandler
const dnnl::engine engine, platform::Place cpu_place,
const Tensor* x, const Tensor* y, Tensor* z,
float scale_x, float scale_y, float scale_z,
- const dnnl::post_ops& post_ops = dnnl::post_ops())
+ const dnnl::post_ops& post_ops = dnnl::post_ops{})
: platform::MKLDNNHandlerNoCachingT<T, dnnl::binary>(engine, cpu_place) {
PADDLE_ENFORCE_EQ(
x->layout(), DataLayout::kMKLDNN,
@@ -676,8 +676,8 @@ class BinaryMKLDNNHandler
const auto dst_md = memory::desc(dst_tz, platform::MKLDNNGetDataType<T>(),
MKLDNNMemoryFormat::any);

- auto attributes = CreateAttributes(algo, scale_x, scale_y, scale_z);
- attributes.set_post_ops(post_ops);
+ auto attributes =
+     CreateAttributes(algo, scale_x, scale_y, scale_z, post_ops);

this->AcquireForwardPrimitiveDescriptor(attributes, algo, src0_md, src1_md,
dst_md);
@@ -690,10 +690,9 @@
}

private:
- static inline dnnl::primitive_attr CreateAttributes(dnnl::algorithm op,
-                                                     float scale_x,
-                                                     float scale_y,
-                                                     float scale_z) {
+ static inline dnnl::primitive_attr CreateAttributes(
+     dnnl::algorithm op, float scale_x, float scale_y, float scale_z,
+     dnnl::post_ops post_ops = dnnl::post_ops{}) {
// Scales set in attributes for inputs contribute to the output equation
// in the following way (assuming no broadcasting takes place):
// output_i = scale_0 * x_i <+ or *> scale_1 * y_i;
Expand All @@ -718,6 +717,7 @@ class BinaryMKLDNNHandler
{scale_0});
attributes.set_scales(/* input_y_id = */ DNNL_ARG_SRC_1, /* mask = */ 0,
{scale_1});
+ if (post_ops.len() > 0) attributes.set_post_ops(post_ops);
return attributes;
}
};
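
With an eltwise post-op attached, the output equation from the comment above extends to the fused form below, where f_act is the activation selected by activation_type and alpha, beta come from activation_alpha / activation_beta (a sketch, assuming no broadcasting and the default post-op scale of 1):

\text{output}_i = f_{\text{act}}\left(\text{scale}_0 \cdot x_i \,\circ\, \text{scale}_1 \cdot y_i;\ \alpha,\ \beta\right), \qquad \circ \in \{+,\ -,\ \times\}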
