cherrypick grpc fixes (#11692)
typhoonzero committed Jun 25, 2018
1 parent 5778040 commit d2d6e8f
Showing 8 changed files with 43 additions and 38 deletions.
6 changes: 3 additions & 3 deletions cmake/external/grpc.cmake
@@ -40,12 +40,12 @@ ExternalProject_Add(
     # NOTE(wuyi):
     # this package is generated by following steps:
     # 1. git clone -b v1.8.x https://github.com/grpc/grpc.git
-    # 2. submodule update --init
+    # 2. git submodule update --init
     # 3. keep only zlib, cares, protobuf, boringssl under "third_party",
     #    checkout and clean other dirs under third_party
     # 4. remove .git, and package the directory.
-    URL "http://paddlepaddledeps.bj.bcebos.com/grpc-v1.8.x.tar.gz"
-    URL_MD5 "c9c58ee7d0e8929a63155af6a2ecdbd0"
+    URL "http://paddlepaddledeps.bj.bcebos.com/grpc-v1.10.x.tar.gz"
+    URL_MD5 "1f268a2aff6759839dccd256adcc91cf"
     PREFIX ${GRPC_SOURCES_DIR}
     UPDATE_COMMAND ""
     CONFIGURE_COMMAND ""

3 changes: 2 additions & 1 deletion paddle/fluid/operators/distributed/grpc_client.cc
@@ -258,14 +258,15 @@ void GRPCClient::Proceed() {
 }
 
 std::shared_ptr<grpc::Channel> GRPCClient::GetChannel(const std::string& ep) {
-  // TODO(Yancey1989): make grpc client completely thread-safe
   std::lock_guard<std::mutex> guard(chan_mutex_);
   auto it = channels_.find(ep);
   if (it != channels_.end()) {
     return it->second;
   }
 
+  // Channel configurations:
   grpc::ChannelArguments args;
+  args.SetInt(GRPC_ARG_MAX_RECONNECT_BACKOFF_MS, 2000);
   args.SetCompressionAlgorithm(GRPC_COMPRESS_NONE);
   args.SetMaxSendMessageSize(std::numeric_limits<int>::max());
   args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
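
For context, the arguments tuned above only take effect when a channel is first created. A minimal sketch of building such a channel (illustrative, not part of the commit; the helper name and the insecure credentials are assumptions):

    #include <grpc++/grpc++.h>

    #include <limits>
    #include <memory>
    #include <string>

    std::shared_ptr<grpc::Channel> MakeTunedChannel(const std::string& ep) {
      grpc::ChannelArguments args;
      // Cap reconnect backoff at 2 s, so a restarted pserver is re-dialed
      // quickly instead of after the default exponential backoff.
      args.SetInt(GRPC_ARG_MAX_RECONNECT_BACKOFF_MS, 2000);
      args.SetCompressionAlgorithm(GRPC_COMPRESS_NONE);
      // Lift the default 4 MB receive cap; serialized tensors can be large.
      args.SetMaxSendMessageSize(std::numeric_limits<int>::max());
      args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
      return grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(),
                                       args);
    }
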
20 changes: 10 additions & 10 deletions paddle/fluid/operators/distributed/grpc_client.h
@@ -72,6 +72,7 @@ class BaseProcessor {
   virtual void Prepare(const VarHandle& var_info, int64_t time_out) {
     context_.reset(new grpc::ClientContext());
     var_h_ = var_info;
+    context_->set_wait_for_ready(true);
 
     std::chrono::system_clock::time_point deadline =
         std::chrono::system_clock::now() + std::chrono::milliseconds(time_out);
@@ -81,6 +82,7 @@
 
   virtual void Prepare(int64_t time_out) {
     context_.reset(new grpc::ClientContext());
+    context_->set_wait_for_ready(true);
 
     std::chrono::system_clock::time_point deadline =
         std::chrono::system_clock::now() + std::chrono::milliseconds(time_out);
@@ -172,26 +174,24 @@ class GRPCClient : public RPCClient {
 
   bool AsyncSendVar(const std::string& ep, const platform::DeviceContext& ctx,
                     const framework::Scope& scope, const std::string& var_name,
-                    int64_t time_out = RPCClient::rpc_time_out) override;
+                    int64_t time_out = FLAGS_grpc_deadline) override;
 
   bool AsyncGetVar(const std::string& ep, const platform::DeviceContext& ctx,
                    const framework::Scope& scope, const std::string& var_name,
-                   int64_t time_out = RPCClient::rpc_time_out) override;
+                   int64_t time_out = FLAGS_grpc_deadline) override;
 
   bool AsyncPrefetchVar(const std::string& ep,
                         const platform::DeviceContext& ctx,
                         const framework::Scope& scope,
                         const std::string& in_var_name,
                         const std::string& out_var_name,
-                        int64_t time_out = RPCClient::rpc_time_out) override;
+                        int64_t time_out = FLAGS_grpc_deadline) override;
 
-  void AsyncSendBatchBarrier(
-      const std::string& ep,
-      int64_t time_out = RPCClient::rpc_time_out) override;
+  void AsyncSendBatchBarrier(const std::string& ep,
+                             int64_t time_out = FLAGS_grpc_deadline) override;
 
-  void AsyncSendFetchBarrier(
-      const std::string& ep,
-      int64_t time_out = RPCClient::rpc_time_out) override;
+  void AsyncSendFetchBarrier(const std::string& ep,
+                             int64_t time_out = FLAGS_grpc_deadline) override;
 
   void Wait() override;
 
@@ -207,7 +207,7 @@ class GRPCClient : public RPCClient {
   void Proceed();
 
   void AsyncSendComplete(const std::string& ep,
-                         int64_t time_out = RPCClient::rpc_time_out);
+                         int64_t time_out = FLAGS_grpc_deadline);
 
   std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep);
 
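
The two Prepare() additions above work as a pair: set_wait_for_ready(true) keeps an RPC queued while its channel is still connecting, instead of failing fast with UNAVAILABLE, and the deadline bounds the total wait. A standalone sketch of the same pattern (EchoService and its Echo method are hypothetical stand-ins for a generated service):

    #include <grpc++/grpc++.h>

    #include <chrono>
    #include <cstdint>

    grpc::Status CallWithDeadline(EchoService::Stub* stub,
                                  const EchoRequest& req, EchoResponse* res,
                                  int64_t time_out_ms) {
      grpc::ClientContext context;
      // Queue the call until the channel is READY instead of failing fast
      // while it is still (re)connecting...
      context.set_wait_for_ready(true);
      // ...but give up once the deadline passes (status DEADLINE_EXCEEDED).
      context.set_deadline(std::chrono::system_clock::now() +
                           std::chrono::milliseconds(time_out_ms));
      return stub->Echo(&context, req, res);
    }

Without the deadline, wait_for_ready alone could block forever on an endpoint that never comes up; with both set, a dead pserver surfaces as DEADLINE_EXCEEDED after the configured timeout.
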
20 changes: 10 additions & 10 deletions paddle/fluid/operators/distributed/grpc_server.cc
@@ -84,7 +84,7 @@ class RequestSend final : public RequestBase {
 
   void Process() override {
     std::string varname = GetReqName();
-    VLOG(3) << "RequestSend var_name:" << varname;
+    VLOG(4) << "RequestSend var_name:" << varname;
 
     auto scope = request_->GetMutableLocalScope();
     auto invar = request_->GetVar();
@@ -119,7 +119,7 @@ class RequestGet final : public RequestBase {
   void Process() override {
     // proc request.
     std::string varname = request_.varname();
-    VLOG(3) << "RequestGet " << varname;
+    VLOG(4) << "RequestGet " << varname;
 
     auto scope = request_handler_->scope();
     auto invar = scope->FindVar(varname);
@@ -165,7 +165,7 @@ class RequestPrefetch final : public RequestBase {
     // prefetch process...
     std::string in_var_name = request_->Varname();
     std::string out_var_name = request_->OutVarname();
-    VLOG(3) << "RequestPrefetch, in_var_name: " << in_var_name
+    VLOG(4) << "RequestPrefetch, in_var_name: " << in_var_name
             << " out_var_name: " << out_var_name;
 
     auto scope = request_->GetMutableLocalScope();
@@ -188,10 +188,10 @@
 };
 
 void AsyncGRPCServer::WaitServerReady() {
-  VLOG(3) << "AsyncGRPCServer is wait server ready";
+  VLOG(4) << "AsyncGRPCServer is wait server ready";
   std::unique_lock<std::mutex> lock(this->mutex_ready_);
   condition_ready_.wait(lock, [=] { return this->ready_ == 1; });
-  VLOG(3) << "AsyncGRPCServer WaitSeverReady";
+  VLOG(4) << "AsyncGRPCServer WaitSeverReady";
 }
 
 void AsyncGRPCServer::StartServer() {
@@ -230,7 +230,7 @@ void AsyncGRPCServer::StartServer() {
     for (int i = 0; i < threadnum; i++) {
       rpc_threads_[rpc_name].emplace_back(new std::thread(std::bind(
           &AsyncGRPCServer::HandleRequest, this, cq.get(), rpc_name, f)));
-      VLOG(3) << t.first << " creates threads!";
+      VLOG(4) << t.first << " creates threads!";
     }
   }
 
@@ -247,15 +247,15 @@ void AsyncGRPCServer::StartServer() {
     auto& threads = t.second;
     for (size_t i = 0; i < threads.size(); ++i) {
       threads[i]->join();
-      VLOG(3) << t.first << " threads ends!";
+      VLOG(4) << t.first << " threads ends!";
     }
   }
 }
 
 void AsyncGRPCServer::ShutdownQueue() {
   for (auto& t : rpc_cq_) {
     t.second->Shutdown();
-    VLOG(3) << t.first << " shutdown!";
+    VLOG(4) << t.first << " queue shutdown!";
   }
 }
 
@@ -264,15 +264,15 @@ void AsyncGRPCServer::ShutDownImpl() {
   is_shut_down_ = true;
   ShutdownQueue();
 
-  VLOG(3) << "server_ shutdown!";
+  VLOG(4) << "server_ shutdown!";
   server_->Shutdown();
 }
 
 void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
                                           int req_id) {
   std::unique_lock<std::mutex> lock(cq_mutex_);
   if (is_shut_down_) {
-    VLOG(3) << "shutdown, do not TryToRegisterNewSendOne";
+    VLOG(4) << "shutdown, do not TryToRegisterNewSendOne";
     return;
   }
 
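
Most of the churn in this file demotes per-request logging from VLOG(3) to VLOG(4). With glog, a VLOG(n) statement only emits when the active verbosity is at least n, so the demoted lines go quiet at everyday settings while staying available for deep debugging. A small sketch (the function and flag values are illustrative):

    #include <string>

    #include <glog/logging.h>

    void HandleOneRequest(const std::string& varname) {
      VLOG(3) << "emitted when run with verbosity >= 3 (e.g. GLOG_v=3)";
      VLOG(4) << "RequestSend var_name:" << varname;  // now needs GLOG_v=4
    }
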
4 changes: 4 additions & 0 deletions paddle/fluid/operators/distributed/rpc_client.cc
@@ -13,6 +13,10 @@
 // limitations under the License.
 
 #include "paddle/fluid/operators/distributed/rpc_client.h"
+#include "gflags/gflags.h"
+
+// default to 3 min to avoid temporary network failures.
+DEFINE_int32(grpc_deadline, 180000, "deadline timeouts for grpc");
 
 namespace paddle {
 namespace operators {
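
The new flag replaces the hard-coded rpc_time_out constant (120 s, deleted from rpc_client.h below) with a 180 s default that can be overridden at run time. The DEFINE/DECLARE split is standard gflags usage: define the flag once in a .cc file, declare it in every other translation unit that reads it. A minimal sketch with a hypothetical flag name:

    // deadline.cc -- the single definition.
    #include "gflags/gflags.h"
    DEFINE_int32(my_deadline_ms, 180000, "deadline in ms for RPC calls");

    // caller.cc -- any other translation unit.
    #include "gflags/gflags.h"
    DECLARE_int32(my_deadline_ms);

    int64_t CurrentDeadlineMs() {
      // Read at call time, so --my_deadline_ms=60000 on the command line
      // (after gflags::ParseCommandLineFlags) is picked up everywhere.
      return FLAGS_my_deadline_ms;
    }

Two C++ details shape the header changes: a default argument such as int64_t time_out = FLAGS_grpc_deadline is re-evaluated at every call site, and default arguments on virtual functions bind statically, which is why both RPCClient (below) and the GRPCClient overrides (above) carry the same FLAGS_grpc_deadline default rather than inheriting it.
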
19 changes: 10 additions & 9 deletions paddle/fluid/operators/distributed/rpc_client.h
@@ -15,11 +15,14 @@
 #pragma once
 
 #include <string>
+#include "gflags/gflags.h"
 
 #include "paddle/fluid/framework/data_type.h"
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/scope.h"
 
+DECLARE_int32(grpc_deadline);
+
 namespace paddle {
 namespace operators {
 namespace distributed {
@@ -32,26 +35,26 @@ class RPCClient {
                            const platform::DeviceContext& ctx,
                            const framework::Scope& scope,
                            const std::string& var_name,
-                           int64_t time_out = rpc_time_out) = 0;
+                           int64_t time_out = FLAGS_grpc_deadline) = 0;
 
   virtual bool AsyncGetVar(const std::string& ep,
                            const platform::DeviceContext& ctx,
                            const framework::Scope& scope,
                            const std::string& var_name,
-                           int64_t time_out = rpc_time_out) = 0;
+                           int64_t time_out = FLAGS_grpc_deadline) = 0;
 
   virtual bool AsyncPrefetchVar(const std::string& ep,
                                 const platform::DeviceContext& ctx,
                                 const framework::Scope& scope,
                                 const std::string& in_var_name,
                                 const std::string& out_var_name,
-                                int64_t time_out = rpc_time_out) = 0;
+                                int64_t time_out = FLAGS_grpc_deadline) = 0;
 
-  virtual void AsyncSendBatchBarrier(const std::string& ep,
-                                     int64_t time_out = rpc_time_out) = 0;
+  virtual void AsyncSendBatchBarrier(
+      const std::string& ep, int64_t time_out = FLAGS_grpc_deadline) = 0;
 
-  virtual void AsyncSendFetchBarrier(const std::string& ep,
-                                     int64_t time_out = rpc_time_out) = 0;
+  virtual void AsyncSendFetchBarrier(
+      const std::string& ep, int64_t time_out = FLAGS_grpc_deadline) = 0;
 
   // SendComplete tells all the servers that the current trainer has no more data
   // to train, so that the pserver can reduce its barrier count, and continue
@@ -60,8 +63,6 @@ class RPCClient {
 
   virtual void Wait() = 0;
 
-  static constexpr int64_t rpc_time_out = 120 * 1000;
-
   template <typename T>
   static RPCClient* GetInstance() {
     std::call_once(init_flag_, &RPCClient::Init<T>);
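
The truncated GetInstance above is the std::call_once singleton idiom: whichever thread arrives first runs Init&lt;T&gt; exactly once, and every caller then sees the same instance. A self-contained sketch of the idiom with hypothetical names:

    #include <memory>
    #include <mutex>

    class Client {
     public:
      virtual ~Client() = default;

      template <typename T>
      static Client* GetInstance() {
        // Init<T> runs exactly once, however many threads race to get here.
        std::call_once(init_flag_, &Client::Init<T>);
        return instance_.get();
      }

     private:
      template <typename T>
      static void Init() {
        instance_.reset(new T());
      }

      static std::once_flag init_flag_;
      static std::unique_ptr<Client> instance_;
    };

    std::once_flag Client::init_flag_;
    std::unique_ptr<Client> Client::instance_;
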
7 changes: 4 additions & 3 deletions paddle/fluid/operators/distributed/rpc_server.cc
@@ -47,11 +47,12 @@ void RPCServer::WaitBarrier(const std::string& rpc_name) {
     return (barrier_counter_[rpc_name] >= client_num_ || exit_flag_.load());
   });
 
-  VLOG(3) << "batch_barrier_:" << barrier_counter_[rpc_name];
+  VLOG(3) << "batch_barrier_: " << rpc_name << " "
+          << barrier_counter_[rpc_name];
 }
 
 void RPCServer::IncreaseBatchBarrier(const std::string rpc_name) {
-  VLOG(3) << "RPCServer begin IncreaseBatchBarrier " << rpc_name;
+  VLOG(4) << "RPCServer begin IncreaseBatchBarrier " << rpc_name;
   int b = 0;
   std::unique_lock<std::mutex> lock(mutex_);
   b = ++barrier_counter_[rpc_name];
@@ -100,7 +101,7 @@ void RPCServer::SetCond(const std::string& rpc_name) {
 }
 
 void RPCServer::WaitCond(const std::string& rpc_name) {
-  VLOG(3) << "RPCServer WaitCond " << rpc_name;
+  VLOG(4) << "RPCServer WaitCond " << rpc_name;
   int cond = 0;
   {
     std::unique_lock<std::mutex> lock(mutex_);
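
WaitBarrier and IncreaseBatchBarrier above form the usual counter-plus-condition_variable barrier: each trainer bumps a per-RPC counter, and the waiting thread sleeps until the counter reaches client_num_ (or exit is flagged). A compact sketch of the pattern with hypothetical names:

    #include <condition_variable>
    #include <mutex>
    #include <string>
    #include <unordered_map>

    class BatchBarrier {
     public:
      explicit BatchBarrier(int client_num) : client_num_(client_num) {}

      void Wait(const std::string& rpc_name) {
        std::unique_lock<std::mutex> lock(mutex_);
        cv_.wait(lock,
                 [&] { return barrier_counter_[rpc_name] >= client_num_; });
      }

      void Increase(const std::string& rpc_name) {
        {
          std::lock_guard<std::mutex> lock(mutex_);
          ++barrier_counter_[rpc_name];
        }
        cv_.notify_all();  // wake every waiter; each rechecks its predicate
      }

     private:
      const int client_num_;
      std::mutex mutex_;
      std::condition_variable cv_;
      std::unordered_map<std::string, int> barrier_counter_;
    };
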
2 changes: 0 additions & 2 deletions paddle/fluid/operators/listen_and_serv_op.cc
@@ -165,7 +165,6 @@ void ListenAndServOp::RunSyncLoop(
 
 void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
                                    framework::ProgramDesc *program) const {
-  VLOG(3) << "RunAsyncLoop in";
   // grad name to block id
   std::unordered_map<std::string, int32_t> grad_to_block_id;
   std::unordered_map<int32_t, std::string> id_to_grad;
@@ -203,7 +202,6 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
   request_get_handler_->SetGradToPreparedCtx(&grad_to_prepared_ctx);
   request_prefetch_handler_->SetGradToPreparedCtx(&grad_to_prepared_ctx);
 
-  VLOG(3) << "RunAsyncLoop into while";
   while (true) {
     if (rpc_service_->IsExit()) {
       LOG(INFO) << "get exit!rpc_processor break!";