Skip to content

Commit

Permalink
support the fusion of batch_norm and relu for AMP. test=release/1.7 (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
wzzju authored and sneaxiy committed Jan 14, 2020
1 parent fedb609 commit c63a63d
Show file tree
Hide file tree
Showing 17 changed files with 1,587 additions and 3 deletions.
2 changes: 1 addition & 1 deletion cmake/operators.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ function(op_library TARGET)
"tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op"
"fusion_transpose_flatten_concat_op" "fusion_conv_inception_op"
"sync_batch_norm_op" "dgc_op" "fused_fc_elementwise_layernorm_op"
"multihead_matmul_op" "fusion_group_op")
"multihead_matmul_op" "fusion_group_op" "fused_bn_activation_op")
if ("${TARGET}" STREQUAL "${manual_pybind_op}")
set(pybind_flag 1)
endif()
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/framework/details/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ endif()
cc_library(build_strategy SRCS build_strategy.cc DEPS
graph_viz_pass multi_devices_graph_pass
multi_devices_graph_print_pass multi_devices_graph_check_pass
fuse_elewise_add_act_pass multi_batch_merge_pass
fuse_elewise_add_act_pass fuse_bn_act_pass multi_batch_merge_pass
fuse_relu_depthwise_conv_pass
lock_free_optimize_pass
coalesce_grad_tensor_pass fuse_all_reduce_op_pass backward_optimizer_op_deps_pass
Expand Down
8 changes: 8 additions & 0 deletions paddle/fluid/framework/details/build_strategy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
"fuse_relu_depthwise_conv_pass");
AppendPassWithCheck(strategy_.fuse_elewise_add_act_ops_,
"fuse_elewise_add_act_pass");
AppendPassWithCheck(strategy_.fuse_bn_act_ops_, "fuse_bn_act_pass");
// for single card training, fuse_all_reduce_ops is unnecessary.
// coalesce_grad_tensor_pass should be before of MultiDevPass.
AppendPassWithCheck(strategy_.fuse_all_reduce_ops_,
Expand Down Expand Up @@ -369,6 +370,12 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
"GPU, skipped.";
continue;
}
} else if (pass->Type() == "fuse_bn_act_pass") {
if (!use_cuda) {
LOG(WARNING) << "fuse_bn_act_pass is only supported on "
"GPU, skipped.";
continue;
}
} else if (pass->Type() == "mkldnn_placement_pass") {
pass->Set("mkldnn_enabled_op_types",
new std::unordered_set<std::string>(mkldnn_enabled_op_types_));
Expand All @@ -394,6 +401,7 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
USE_PASS(sync_batch_norm_pass);
USE_PASS(fuse_relu_depthwise_conv_pass);
USE_PASS(fuse_elewise_add_act_pass);
USE_PASS(fuse_bn_act_pass);
USE_PASS(graph_viz_pass);
USE_PASS(multi_batch_merge_pass);
USE_PASS(reduce_mode_multi_devices_pass);
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/framework/details/build_strategy.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ struct BuildStrategy {
// TODO(dev-paddle): fuse_elewise_add_act_ops may cause some models have
// cycle.
bool fuse_elewise_add_act_ops_{false};
bool fuse_bn_act_ops_{false};
// Fuse_all_optimizer_ops and fuse_all_reduce_ops require that gradients
// should not be sparse types
boost::optional<bool> fuse_all_optimizer_ops_{false};
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/framework/ir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ if(WITH_NGRAPH)
set(INFER_IR_PASSES ${INFER_IR_PASSES} ngraph_subgraph_pass CACHE INTERNAL "")
endif()

cc_library(fuse_bn_act_pass SRCS fuse_bn_act_pass.cc DEPS pass graph_pattern_detector )
cc_library(fuse_elewise_add_act_pass SRCS fuse_elewise_add_act_pass.cc DEPS pass graph_pattern_detector )
cc_library(fuse_relu_depthwise_conv_pass SRCS fuse_relu_depthwise_conv_pass.cc DEPS pass graph_pattern_detector )

Expand Down
332 changes: 332 additions & 0 deletions paddle/fluid/framework/ir/fuse_bn_act_pass.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,332 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/fuse_bn_act_pass.h"
#include <algorithm>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/enforce.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cudnn_helper.h"
#endif

namespace paddle {
namespace framework {
namespace ir {

void FuseBatchNormActPass::ApplyImpl(ir::Graph *graph) const {
#ifdef PADDLE_WITH_CUDA
#if CUDNN_VERSION_MIN(7, 4, 1)
// forward
std::unordered_set<std::string> act_types = {"relu"};
graph = FuseBatchNormAct(graph, act_types);
// backward
std::unordered_set<std::string> act_grad_types = {"relu_grad"};
graph = FuseBatchNormActGrad(graph, act_grad_types);
#endif
#endif
}

// act(bn(x))
ir::Graph *FuseBatchNormActPass::FuseBatchNormAct(
ir::Graph *graph, const std::unordered_set<std::string> &act_types) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument(
"The input graph of FuseBatchNormAct should not be nullptr."));
FusePassBase::Init("bn_act", graph);

GraphPatternDetector gpd;
auto *x = gpd.mutable_pattern()
->NewNode("bn_act/x")
->AsInput()
->assert_is_op_input("batch_norm", "X")
->assert_var_dtype(proto::VarType::FP16);
patterns::BatchNormAct bn_act_pattern(gpd.mutable_pattern(), "bn_act");

bn_act_pattern(x, act_types);

int found_bn_act_count = 0;

auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) {
VLOG(4) << "handle FuseBatchNormAct fuse";
// BN inputs
GET_IR_NODE_FROM_SUBGRAPH(bn_scale, bn_scale, bn_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(bn_bias, bn_bias, bn_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(bn_variance, bn_variance, bn_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(bn_mean, bn_mean, bn_act_pattern);
// BN outputs
GET_IR_NODE_FROM_SUBGRAPH(bn_mean_out, bn_mean_out, bn_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(bn_variance_out, bn_variance_out, bn_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(bn_saved_variance, bn_saved_variance,
bn_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(bn_saved_mean, bn_saved_mean, bn_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(bn_reserve_space, bn_reserve_space,
bn_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(bn_out, bn_out, bn_act_pattern);
// ACT output
GET_IR_NODE_FROM_SUBGRAPH(act_out, act_out, bn_act_pattern);
// ops
GET_IR_NODE_FROM_SUBGRAPH(batch_norm, batch_norm, bn_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(act, act, bn_act_pattern);

std::string bn_x_n = subgraph.at(x)->Name();
std::string bn_scale_n = bn_scale->Name();
std::string bn_bias_n = bn_bias->Name();
std::string bn_variance_n = bn_variance->Name();
std::string bn_mean_n = bn_mean->Name();
std::string bn_mean_out_n = bn_mean_out->Name();
std::string bn_variance_out_n = bn_variance_out->Name();
std::string bn_saved_variance_n = bn_saved_variance->Name();
std::string bn_saved_mean_n = bn_saved_mean->Name();
std::string bn_reserve_space_n = bn_reserve_space->Name();
std::string bn_out_n = bn_out->Name();
std::string act_out_n = act_out->Name();

Node *fused_bn_act_node = CreateFusedBatchNormActNode(
g, act, batch_norm, bn_x_n, bn_scale_n, bn_bias_n, bn_variance_n,
bn_mean_n, bn_mean_out_n, bn_variance_out_n, bn_saved_variance_n,
bn_saved_mean_n, bn_reserve_space_n, act_out_n);

VLOG(4) << "\n\t " << bn_x_n << ", " << bn_scale_n << ", " << bn_bias_n
<< ", " << bn_variance_n << " and " << bn_mean_n << " -> "
<< batch_norm->Name() << " -> " << bn_mean_out_n << ", "
<< bn_variance_out_n << ", " << bn_saved_variance_n << ", "
<< bn_saved_mean_n << ", " << bn_reserve_space_n << " and "
<< bn_out_n << "\n"
<< "\t " << bn_out_n << " -> " << act->Name() << " -> "
<< act_out_n;

ReLinkNodes(g, bn_out, batch_norm, act, fused_bn_act_node);
found_bn_act_count++;
};

gpd(graph, handler);

AddStatis(found_bn_act_count);
return graph;
}

Node *FuseBatchNormActPass::CreateFusedBatchNormActNode(
Graph *g, const Node *act, const Node *bn, const std::string &bn_x_n,
const std::string &bn_scale_n, const std::string &bn_bias_n,
const std::string &bn_variance_n, const std::string &bn_mean_n,
const std::string &bn_mean_out_n, const std::string &bn_variance_out_n,
const std::string &bn_saved_variance_n, const std::string &bn_saved_mean_n,
const std::string &bn_reserve_space_n, const std::string &act_out_n) const {
OpDesc desc;
desc.SetInput("X", std::vector<std::string>({bn_x_n}));
desc.SetInput("Scale", std::vector<std::string>({bn_scale_n}));
desc.SetInput("Bias", std::vector<std::string>({bn_bias_n}));
desc.SetInput("Mean", std::vector<std::string>({bn_mean_n}));
desc.SetInput("Variance", std::vector<std::string>({bn_variance_n}));

desc.SetOutput("Y", std::vector<std::string>({act_out_n}));
desc.SetOutput("MeanOut", std::vector<std::string>({bn_mean_out_n}));
desc.SetOutput("VarianceOut", std::vector<std::string>({bn_variance_out_n}));
desc.SetOutput("SavedMean", std::vector<std::string>({bn_saved_mean_n}));
desc.SetOutput("SavedVariance",
std::vector<std::string>({bn_saved_variance_n}));
desc.SetOutput("ReserveSpace",
std::vector<std::string>({bn_reserve_space_n}));
desc.SetType("fused_batch_norm_act");

desc.SetAttr("act_type", act->Name());
// Set attrs
for (auto &n : {act->Op(), bn->Op()}) {
for (auto &m : n->GetAttrMap()) {
desc.SetAttr(m.first, m.second);
}
}

auto fused_bn_act_node = g->CreateOpNode(&desc);
return fused_bn_act_node;
}

// the backward of act(bn(x))
// act_grad: in["Out", "Out@GRAD"], out["X@GRAD"]
// bn_grad: in["X", "Y@GRAD", "Scale", "Bias", "SavedMean", "SavedVariance",
// "ReserveSpace"],
// out["X@GRAD", "Scale@GRAD", "Bias@GRAD"]
ir::Graph *FuseBatchNormActPass::FuseBatchNormActGrad(
ir::Graph *graph,
const std::unordered_set<std::string> &act_grad_types) const {
PADDLE_ENFORCE_NOT_NULL(
graph,
platform::errors::InvalidArgument(
"The input graph of FuseBatchNormActGrad should not be nullptr."));
FusePassBase::Init("bn_act_grad", graph);

GraphPatternDetector gpd;
auto *d_act_out =
gpd.mutable_pattern()
->NewNode("bn_act_grad/x")
->AsInput()
->assert_is_ops_input(act_grad_types, GradVarName("Out"));
patterns::BatchNormActGrad bn_act_grad_pattern(gpd.mutable_pattern(),
"bn_act_grad");
bn_act_grad_pattern(d_act_out, act_grad_types);

int found_bn_act_count = 0;

auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) {
VLOG(4) << "handle FuseBatchNormActGrad fuse";
GET_IR_NODE_FROM_SUBGRAPH(act_grad, act_grad, bn_act_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(batch_norm_grad, batch_norm_grad,
bn_act_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(act_out, act_out, bn_act_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(d_itermediate_out, d_itermediate_out,
bn_act_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(bn_x, bn_x, bn_act_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(bn_scale, bn_scale, bn_act_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(bn_bias, bn_bias, bn_act_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(bn_saved_mean, bn_saved_mean,
bn_act_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(bn_saved_variance, bn_saved_variance,
bn_act_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(bn_reserve_space, bn_reserve_space,
bn_act_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(d_bn_x, d_bn_x, bn_act_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(d_bn_scale, d_bn_scale, bn_act_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(d_bn_bias, d_bn_bias, bn_act_grad_pattern);

std::string d_act_out_n = subgraph.at(d_act_out)->Name(); // Y@GRAD
std::string act_out_n = act_out->Name(); // Y
std::string d_itermediate_out_n = d_itermediate_out->Name();
std::string bn_x_n = bn_x->Name();
std::string bn_scale_n = bn_scale->Name();
std::string bn_bias_n = bn_bias->Name();
std::string bn_saved_mean_n = bn_saved_mean->Name();
std::string bn_saved_variance_n = bn_saved_variance->Name();
std::string bn_reserve_space_n = bn_reserve_space->Name();
std::string d_bn_x_n = d_bn_x->Name();
std::string d_bn_scale_n = d_bn_scale->Name();
std::string d_bn_bias_n = d_bn_bias->Name();

OpDesc desc;
desc.SetType("fused_batch_norm_act_grad");
desc.SetInput("X", {bn_x_n});
desc.SetInput("Y", std::vector<std::string>({act_out_n}));
desc.SetInput(GradVarName("Y"), std::vector<std::string>({d_act_out_n}));
desc.SetInput("Scale", std::vector<std::string>({bn_scale_n}));
desc.SetInput("Bias", std::vector<std::string>({bn_bias_n}));
desc.SetInput("SavedMean", std::vector<std::string>({bn_saved_mean_n}));
desc.SetInput("SavedVariance",
std::vector<std::string>({bn_saved_variance_n}));
desc.SetInput("ReserveSpace",
std::vector<std::string>({bn_reserve_space_n}));
desc.SetOutput(GradVarName("X"), std::vector<std::string>({d_bn_x_n}));
desc.SetOutput(GradVarName("Scale"),
std::vector<std::string>({d_bn_scale_n}));
desc.SetOutput(GradVarName("Bias"),
std::vector<std::string>({d_bn_bias_n}));
std::string act = act_grad->Name();
act = act.substr(0, act.length() - 5); // remove "_grad"
desc.SetAttr("act_type", act);

for (auto &n : {act_grad->Op(), batch_norm_grad->Op()}) {
for (auto &m : n->GetAttrMap()) {
desc.SetAttr(m.first, m.second);
}
}

auto fused_node = g->CreateOpNode(&desc);

VLOG(4) << "\n\t " << d_act_out_n << " and " << act_out_n << " -> "
<< act_grad->Name() << " -> " << d_itermediate_out_n << "\n\t "
<< bn_x_n << ", " << d_itermediate_out_n << ", " << bn_scale_n
<< ", " << bn_bias_n << ", " << bn_saved_mean_n << ", "
<< bn_saved_variance_n << " and " << bn_reserve_space_n << " -> "
<< batch_norm_grad->Name() << " -> " << d_bn_x_n << ", "
<< d_bn_scale_n << " and " << d_bn_bias_n;

ReLinkNodes(g, d_itermediate_out, act_grad, batch_norm_grad, fused_node);
found_bn_act_count++;
};

gpd(graph, handler);

AddStatis(found_bn_act_count);
return graph;
}

void FuseBatchNormActPass::ReLinkNodes(Graph *graph,
const Node *intermediate_out, Node *op_1,
Node *op_2,
Node *fused_op) const { // delete act
for (auto &in : op_1->inputs) {
fused_op->inputs.emplace_back(in);
in->outputs = this->ReplaceNode(op_1, fused_op, in->outputs);
}

std::unordered_set<const Node *> nodes2delete;
for (auto &out : op_1->outputs) {
// intermediate_out or ctr_var
auto result_iter =
std::find_if(op_2->inputs.begin(), op_2->inputs.end(),
[&out](const Node *node) -> bool { return node == out; });

if (result_iter == op_2->inputs.end()) {
IR_OP_VAR_LINK(fused_op, out);
} else {
nodes2delete.emplace(out);
}
}

for (auto &in : op_2->inputs) {
if (in == intermediate_out || nodes2delete.count(in)) {
continue;
}
fused_op->inputs.emplace_back(in);
in->outputs = this->ReplaceNode(op_2, fused_op, in->outputs);
}

for (auto &out : op_2->outputs) {
IR_OP_VAR_LINK(fused_op, out);
}

nodes2delete.insert(std::move(op_1));
nodes2delete.insert(std::move(op_2));

GraphSafeRemoveNodes(graph, nodes2delete);
}

std::vector<Node *> FuseBatchNormActPass::ReplaceNode(
Node *cur_node, Node *new_node, const std::vector<Node *> &nodes) const {
std::vector<Node *> new_list(nodes.size());
bool has_replaced = false;
std::transform(nodes.begin(), nodes.end(), new_list.begin(),
[&](Node *node) -> Node * {
if (node == cur_node) {
has_replaced = true;
return new_node;
}
return node;
});
PADDLE_ENFORCE_EQ(has_replaced, true,
platform::errors::NotFound("Not find %s in the node list.",
cur_node->Name()));
return new_list;
}

} // namespace ir
} // namespace framework
} // namespace paddle

REGISTER_PASS(fuse_bn_act_pass, paddle::framework::ir::FuseBatchNormActPass);
Loading

0 comments on commit c63a63d

Please sign in to comment.