Merge pull request PaddlePaddle#4 from Yancey1989/fuse_broadcast
Fuse broadcast
typhoonzero committed Oct 12, 2018
2 parents 51ed711 + 9eebf83 commit 2124c76
Showing 11 changed files with 196 additions and 14 deletions.
5 changes: 5 additions & 0 deletions benchmark/fluid/args.py
@@ -142,5 +142,10 @@ def parse_args():
choices=['reduce', 'all_reduce'],
default='all_reduce',
help='Specify the reduce strategy, can be reduce, all_reduce')
parser.add_argument(
'--fuse_broadcast_op',
action='store_true',
help='If set, fuses multiple broadcast operators into one fused_broadcast operator.'
)
args = parser.parse_args()
return args
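
A minimal sketch of the new flag's behaviour (standard argparse store_true semantics; illustrative only, not the benchmark script itself):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    '--fuse_broadcast_op',
    action='store_true',
    help='Fuse multiple broadcast operators into one fused_broadcast operator.')

print(parser.parse_args([]).fuse_broadcast_op)                       # False by default
print(parser.parse_args(['--fuse_broadcast_op']).fuse_broadcast_op)  # True when the flag is passed
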
2 changes: 1 addition & 1 deletion benchmark/fluid/fluid_benchmark.py
@@ -177,6 +177,7 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog,
else:
build_strategy.reduce_strategy = fluid.BuildStrategy(
).ReduceStrategy.AllReduce
build_strategy.fuse_broadcast_op = args.fuse_broadcast_op

avg_loss = train_args[0]

@@ -240,7 +241,6 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog,

if args.use_fake_data or args.use_reader_op:
try:

fetch_ret = exe.run(fetch_list)
except fluid.core.EOFException as eof:
break
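
For context, a minimal usage sketch of how the new flag reaches the executor from Python, assuming the Fluid API of this release; the tiny network, optimizer, and place below are placeholders, not part of the change:

import paddle.fluid as fluid

# a trivial network so the sketch is self-contained
x = fluid.layers.data(name='x', shape=[1], dtype='float32')
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
pred = fluid.layers.fc(input=x, size=1)
loss = fluid.layers.mean(fluid.layers.square_error_cost(input=pred, label=y))
fluid.optimizer.SGD(learning_rate=0.01).minimize(loss)

place = fluid.CUDAPlace(0)
fluid.Executor(place).run(fluid.default_startup_program())

build_strategy = fluid.BuildStrategy()
build_strategy.reduce_strategy = fluid.BuildStrategy().ReduceStrategy.Reduce
build_strategy.fuse_broadcast_op = True  # the switch exposed by this change

exe = fluid.ParallelExecutor(use_cuda=True, loss_name=loss.name,
                             build_strategy=build_strategy)
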
4 changes: 3 additions & 1 deletion paddle/fluid/framework/details/CMakeLists.txt
@@ -16,12 +16,14 @@ if(WITH_GPU)
dynload_cuda variable_visitor)
nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim dynload_cuda)
nv_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor dynload_cuda)
nv_library(fused_broadcast_op_handle SRCS fused_broadcast_op_handle.cc DEPS broadcast_op_handle)

else()
cc_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
variable_visitor)
cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim)
cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
cc_library(fused_broadcast_op_handle SRCS fused_broadcast_op_handle.cc DEPS broadcast_op_handle)
endif()

cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_base scope lod_tensor)
@@ -34,7 +36,7 @@ if(WITH_GPU)
endif()

cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle)
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle fused_broadcast_op_handle)

if(WITH_GPU)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto reference_count_pass)
21 changes: 14 additions & 7 deletions paddle/fluid/framework/details/broadcast_op_handle.cc
@@ -48,16 +48,23 @@ void BroadcastOpHandle::RunImpl() {
var_scopes.emplace_back(s->FindVar(kLocalExecScopeName)->Get<Scope *>());
}

BroadcastOneVar(*in_var_handle, out_var_handles, var_scopes);
}

void BroadcastOpHandle::BroadcastOneVar(
const VarHandle &in_var_handle,
const std::vector<VarHandle *> &out_var_handles,
const std::vector<const Scope *> &var_scopes) {
auto *in_var =
var_scopes.at(in_var_handle->scope_idx_)->FindVar(in_var_handle->name_);
var_scopes.at(in_var_handle.scope_idx_)->FindVar(in_var_handle.name_);
PADDLE_ENFORCE_NOT_NULL(in_var);
Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var);

InitOutputValue(*in_var_handle, out_var_handles);
InitOutputValue(in_var_handle, out_var_handles);

if (platform::is_cpu_place(in_tensor.place())) {
for (auto *out_var_handle : out_var_handles) {
if (out_var_handle->IsTheSameVar(*in_var_handle)) {
if (out_var_handle->IsTheSameVar(in_var_handle)) {
continue;
}
auto &out_p = out_var_handle->place_;
@@ -114,12 +121,12 @@ void BroadcastOpHandle::RunImpl() {
}
}

if (!out_handle->IsTheSameVar(*in_var_handle)) {
auto out_var = var_scopes.at(in_var_handle->scope_idx_)
if (!out_handle->IsTheSameVar(in_var_handle)) {
auto out_var = var_scopes.at(in_var_handle.scope_idx_)
->FindVar(out_var_handles[0]->name_);
paddle::framework::TensorCopy(
in_tensor, in_var_handle->place_,
*(dev_ctxes_.at(in_var_handle->place_)),
in_tensor, in_var_handle.place_,
*(dev_ctxes_.at(in_var_handle.place_)),
&VariableVisitor::GetMutableTensor(out_var));
}
});
5 changes: 4 additions & 1 deletion paddle/fluid/framework/details/broadcast_op_handle.h
@@ -61,7 +61,10 @@ struct BroadcastOpHandle : public OpHandleBase {
protected:
void RunImpl() override;

private:
void BroadcastOneVar(const VarHandle &in_var_handle,
const std::vector<VarHandle *> &out_var_handles,
const std::vector<const Scope *> &var_scopes);

std::vector<Scope *> local_scopes_;
std::vector<platform::Place> places_;
#ifdef PADDLE_WITH_CUDA
2 changes: 2 additions & 0 deletions paddle/fluid/framework/details/build_strategy.h
@@ -69,6 +69,8 @@ struct BuildStrategy {

bool enable_data_balance_{false};

bool fuse_broadcast_op_{false};

int merge_batches_repeats_{1};

// User normally doesn't need to call this API.
55 changes: 55 additions & 0 deletions paddle/fluid/framework/details/fused_broadcast_op_handle.cc
@@ -0,0 +1,55 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/details/fused_broadcast_op_handle.h"
#include "paddle/fluid/framework/details/container_cast.h"
#include "paddle/fluid/framework/details/variable_visitor.h"
#include "paddle/fluid/platform/profiler.h"

namespace paddle {
namespace framework {
namespace details {

void FusedBroadcastOpHandle::RunImpl() {
platform::RecordEvent record_event(Name(), dev_ctxes_.begin()->second);

if (places_.size() == 1UL) return;

auto in_var_handles = DynamicCast<VarHandle>(inputs_);
auto out_var_handles = DynamicCast<VarHandle>(outputs_);

WaitInputVarGenerated();

std::vector<const Scope *> var_scopes;
for (auto *s : local_scopes_) {
var_scopes.emplace_back(s->FindVar(kLocalExecScopeName)->Get<Scope *>());
}

size_t place_num = places_.size();
PADDLE_ENFORCE_EQ(in_var_handles.size() * place_num, out_var_handles.size());

for (size_t i = 0; i < in_var_handles.size(); ++i) {
BroadcastOneVar(
*in_var_handles[i],
std::vector<VarHandle *>(out_var_handles.begin() + i * place_num,
out_var_handles.begin() + (i + 1) * place_num),
var_scopes);
}
}

std::string FusedBroadcastOpHandle::Name() const { return "fused_broadcast"; }

} // namespace details
} // namespace framework
} // namespace paddle
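
The fused handle reuses BroadcastOpHandle::BroadcastOneVar once per input variable, and assumes its output handles form a flat list of in_var_handles.size() * place_num entries, with input i owning the contiguous slice [i * place_num, (i + 1) * place_num). A small illustrative sketch of that layout (hypothetical names, not framework code):

def split_outputs(in_vars, out_vars, place_num):
    # mirrors the PADDLE_ENFORCE_EQ check and the slicing in FusedBroadcastOpHandle::RunImpl
    assert len(in_vars) * place_num == len(out_vars)
    return [(in_vars[i], out_vars[i * place_num:(i + 1) * place_num])
            for i in range(len(in_vars))]

# two fused variables broadcast to three places -> six output handles
print(split_outputs(['w0', 'w1'],
                    ['w0@p0', 'w0@p1', 'w0@p2', 'w1@p0', 'w1@p1', 'w1@p2'], 3))
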
57 changes: 57 additions & 0 deletions paddle/fluid/framework/details/fused_broadcast_op_handle.h
@@ -0,0 +1,57 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <map>
#include <string>
#include <vector>

#include "paddle/fluid/framework/details/broadcast_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/platform/device_context.h"

#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/nccl_helper.h"
#endif

namespace paddle {
namespace framework {
namespace details {

struct FusedBroadcastOpHandle : public BroadcastOpHandle {
public:
#ifdef PADDLE_WITH_CUDA
FusedBroadcastOpHandle(ir::Node *node,
const std::vector<Scope *> local_scopes,
const std::vector<platform::Place> &places,
const platform::NCCLContextMap *nccl_ctx)
: BroadcastOpHandle(node, local_scopes, places, nccl_ctx) {}
#else
FusedBroadcastOpHandle(ir::Node* node, const std::vector<Scope*> local_scopes,
const std::vector<platform::Place>& places)
: BroadcastOpHandle(node, local_scopes, places) {}
#endif
std::string Name() const override;

protected:
void RunImpl() override;
};

} // namespace details
} // namespace framework
} // namespace paddle
51 changes: 47 additions & 4 deletions paddle/fluid/framework/details/multi_devices_graph_pass.cc
@@ -21,6 +21,7 @@
#include "paddle/fluid/framework/details/broadcast_op_handle.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/data_balance_op_handle.h"
#include "paddle/fluid/framework/details/fused_broadcast_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_graph_pass.h"
#include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/details/rpc_op_handle.h"
@@ -436,10 +437,14 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
if ((use_gpu &&
strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) ||
is_dist_train) {
for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) {
auto &to_bcast_set = bcast_var_name_set[dev_id];
for (auto &bcast_name : to_bcast_set) {
CreateBroadcastOp(&result, bcast_name, dev_id);
if (strategy_.fuse_broadcast_op_) {
CreateFusedBroadcastOp(&result, bcast_var_name_set);
} else {
for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) {
auto &to_bcast_set = bcast_var_name_set[dev_id];
for (auto &bcast_name : to_bcast_set) {
CreateBroadcastOp(&result, bcast_name, dev_id);
}
}
}
}
@@ -508,6 +513,44 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result,
}
}

void MultiDevSSAGraphBuilder::CreateFusedBroadcastOp(
ir::Graph *result,
const std::vector<std::unordered_set<std::string>> &bcast_varnames) const {
#ifdef PADDLE_WITH_CUDA
auto *op_handle = new FusedBroadcastOpHandle(
result->CreateEmptyNode("fused_broadcast", ir::Node::Type::kOperation),
local_scopes_, places_, nccl_ctxs_);
#else
auto *op_handle = new FusedBroadcastOpHandle(
result->CreateEmptyNode("fused_broadcast" m ir::Node::Type::kOperation),
local_scopes_, places_);
#endif
result->Get<GraphOps>(kGraphOps).emplace_back(op_handle);

for (size_t i = 0; i < places_.size(); ++i) {
auto &p = places_[i];
SetCommunicationContext(op_handle, p);
}

for (size_t dev_id = 0; dev_id < bcast_varnames.size(); ++dev_id) {
for (auto &p_name : bcast_varnames[dev_id]) {
auto *in =
result->Get<GraphVars>(kGraphVars).at(dev_id).at(p_name).back().get();
op_handle->AddInput(in);
for (size_t out_dev_id = 0; out_dev_id < places_.size(); ++out_dev_id) {
auto &p = places_[out_dev_id];
auto &vars =
result->Get<GraphVars>(kGraphVars).at(out_dev_id).at(p_name);
auto *out_var = new VarHandle(
result->CreateEmptyNode(p_name, ir::Node::Type::kVariable),
vars.size(), out_dev_id, p_name, p);
vars.emplace_back(out_var);
op_handle->AddOutput(out_var);
}
}
}
}

void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result,
ir::Node *node,
int dev_id) const {
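
With fuse_broadcast_op_ enabled, the pass emits a single fused_broadcast op handle that takes every to-be-broadcast variable as input and adds one output handle per (variable, place), instead of one broadcast op per variable. A rough sketch of the resulting op counts (hypothetical variable names, illustrative only):

# vars to broadcast, grouped by the device they were reduced onto
bcast_var_name_set = [{'fc_0.w_0'}, {'fc_1.w_0', 'fc_1.b_0'}]
places = 4

unfused_ops = sum(len(names) for names in bcast_var_name_set)  # 3 broadcast op handles
fused_ops = 1                                                  # 1 fused_broadcast op handle
output_handles = unfused_ops * places                          # 12 output var handles either way
print(unfused_ops, fused_ops, output_handles)
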
4 changes: 4 additions & 0 deletions paddle/fluid/framework/details/multi_devices_graph_pass.h
@@ -79,6 +79,10 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
void CreateBroadcastOp(ir::Graph *result, const std::string &p_name,
size_t src_dev_id) const;

void CreateFusedBroadcastOp(
ir::Graph *result,
const std::vector<std::unordered_set<std::string>> &bcast_varnames) const;

bool IsSparseGradient(const std::string &og) const;

size_t GetAppropriateDeviceID(
4 changes: 4 additions & 0 deletions paddle/fluid/pybind/pybind.cc
@@ -728,6 +728,10 @@ All parameter, weight, gradient are variables in Paddle.
"enable_data_balance",
[](const BuildStrategy &self) { return self.enable_data_balance_; },
[](BuildStrategy &self, bool b) { self.enable_data_balance_ = b; })
.def_property(
"fuse_broadcast_op",
[](const BuildStrategy &self) { return self.fuse_broadcast_op_; },
[](BuildStrategy &self, bool b) { self.fuse_broadcast_op_ = b; })
.def_property("fuse_elewise_add_act_ops",
[](const BuildStrategy &self) {
return self.fuse_elewise_add_act_ops_;
