From d93959f0fb6580af95aea2645f99b163d69c82f5 Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Tue, 30 Jan 2018 19:07:40 +0800 Subject: [PATCH 1/2] perf enhance reuse connection --- paddle/operators/send_op.cc | 27 +++++++++++++------ .../paddle/v2/fluid/distribute_transpiler.py | 9 ++++++- 2 files changed, 27 insertions(+), 9 deletions(-) diff --git a/paddle/operators/send_op.cc b/paddle/operators/send_op.cc index bb719dc2a8a57..0be3b37859508 100644 --- a/paddle/operators/send_op.cc +++ b/paddle/operators/send_op.cc @@ -19,6 +19,7 @@ limitations under the License. */ #include "paddle/framework/lod_tensor.h" #include "paddle/framework/op_registry.h" +#include #include #include "paddle/operators/detail/grpc_client.h" @@ -42,28 +43,35 @@ class SendOp : public framework::OperatorBase { platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto& ctx = *pool.Get(place); + + auto client_var_name = Output("RPCClient"); + PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name), + "Can not find variable '%s' in the scope.", + client_var_name); + auto* client_var = scope.FindVar(client_var_name); + detail::RPCClient* rpc_client = client_var->GetMutable(); + for (size_t i = 0; i < ins.size(); i++) { VLOG(3) << "sending " << ins[i] << " to " << epmap[i]; - client_.AsyncSendVariable(epmap[i], ctx, scope, ins[i]); + rpc_client->AsyncSendVariable(epmap[i], ctx, scope, ins[i]); } - PADDLE_ENFORCE(client_.Wait()); + PADDLE_ENFORCE(rpc_client->Wait()); for (auto& ep : endpoints) { VLOG(3) << "batch barrier, ep: " << ep; - client_.AsyncSendBatchBarrier(ep); + rpc_client->AsyncSendBatchBarrier(ep); } - PADDLE_ENFORCE(client_.Wait()); + PADDLE_ENFORCE(rpc_client->Wait()); for (size_t i = 0; i < outs.size(); i++) { VLOG(3) << "getting " << outs[i] << " from " << epmap[i]; - client_.AsyncGetVariable(epmap[i], ctx, scope, outs[i]); + rpc_client->AsyncGetVariable(epmap[i], ctx, scope, outs[i]); } - - PADDLE_ENFORCE(client_.Wait()); + PADDLE_ENFORCE(rpc_client->Wait()); } private: - mutable detail::RPCClient client_; + // mutable detail::RPCClient client_; }; class SendOpMaker : public framework::OpProtoAndCheckerMaker { @@ -73,6 +81,9 @@ class SendOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("X", "(Tensor) Input tensor to be sent").AsDuplicable(); AddOutput("Out", "(Tensor) Output tensor to be received from server") .AsDuplicable(); + AddOutput("RPCClient", + "(RPCClient) The RPC client object which is" + "initialized at most once."); AddComment(R"DOC( Send operator diff --git a/python/paddle/v2/fluid/distribute_transpiler.py b/python/paddle/v2/fluid/distribute_transpiler.py index 77f80442e06cb..a4464a281aae7 100644 --- a/python/paddle/v2/fluid/distribute_transpiler.py +++ b/python/paddle/v2/fluid/distribute_transpiler.py @@ -153,11 +153,18 @@ def transpile(self, self.param_grad_ep_mapping[ep]["params"].append(param) self.param_grad_ep_mapping[ep]["grads"].append(grad) + rpc_client_var = program.global_block().create_var( + name="RPC_CLIENT_VAR", + psersistable=True, + dtype='float32', # dtype and shape is not used in fact + shape=[0]) + # create send_op send_op = program.global_block().append_op( type="send", inputs={"X": send_inputs}, - outputs={"Out": send_outputs}, + outputs={"Out": send_outputs, + "RPCClient": rpc_client_var}, attrs={"endpoints": pserver_endpoints, "epmap": eplist}) # step4 From 683c5a3eb58ad3c75644a822b2e159e6b37b5b49 Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Tue, 30 Jan 2018 19:09:19 +0800 Subject: [PATCH 2/2] clean up code --- paddle/operators/send_op.cc | 4 ---- 1 file changed, 4 deletions(-) diff --git a/paddle/operators/send_op.cc b/paddle/operators/send_op.cc index 0be3b37859508..be41b527f2289 100644 --- a/paddle/operators/send_op.cc +++ b/paddle/operators/send_op.cc @@ -19,7 +19,6 @@ limitations under the License. */ #include "paddle/framework/lod_tensor.h" #include "paddle/framework/op_registry.h" -#include #include #include "paddle/operators/detail/grpc_client.h" @@ -69,9 +68,6 @@ class SendOp : public framework::OperatorBase { } PADDLE_ENFORCE(rpc_client->Wait()); } - - private: - // mutable detail::RPCClient client_; }; class SendOpMaker : public framework::OpProtoAndCheckerMaker {