From 94eea16e6de736f6b08c2cdf2b99c0057ccfbca5 Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Tue, 3 Apr 2018 15:02:45 +0800 Subject: [PATCH 1/2] fix sendrecv port bind --- paddle/fluid/operators/detail/grpc_server.cc | 9 +- paddle/fluid/operators/detail/grpc_server.h | 3 + paddle/fluid/operators/listen_and_serv_op.cc | 272 ++++++++----------- paddle/fluid/operators/listen_and_serv_op.h | 85 ++++++ paddle/fluid/operators/send_recv_op_test.cc | 27 +- 5 files changed, 233 insertions(+), 163 deletions(-) create mode 100644 paddle/fluid/operators/listen_and_serv_op.h diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc index 7c978b28b6873..b81d37415848f 100644 --- a/paddle/fluid/operators/detail/grpc_server.cc +++ b/paddle/fluid/operators/detail/grpc_server.cc @@ -186,7 +186,8 @@ void AsyncGRPCServer::WaitClientGet(int count) { void AsyncGRPCServer::RunSyncUpdate() { ::grpc::ServerBuilder builder; - builder.AddListeningPort(address_, ::grpc::InsecureServerCredentials()); + builder.AddListeningPort(address_, ::grpc::InsecureServerCredentials(), + &selected_port_); builder.SetMaxSendMessageSize(std::numeric_limits::max()); builder.SetMaxReceiveMessageSize(std::numeric_limits::max()); builder.RegisterService(&service_); @@ -196,7 +197,8 @@ void AsyncGRPCServer::RunSyncUpdate() { cq_prefetch_ = builder.AddCompletionQueue(); server_ = builder.BuildAndStart(); - LOG(INFO) << "Server listening on " << address_ << std::endl; + LOG(INFO) << "Server listening on " << address_ + << " selected port: " << selected_port_; std::function send_register = std::bind(&AsyncGRPCServer::TryToRegisterNewSendOne, this); @@ -242,6 +244,9 @@ void AsyncGRPCServer::TryToRegisterNewSendOne() { VLOG(3) << "shutdown, do not TryToRegisterNewSendOne"; return; } + while (scope_ == nullptr) { + sleep(0.01); + } RequestSend* send = new RequestSend(&service_, cq_send_.get(), scope_, &var_recv_queue_, dev_ctx_); VLOG(4) << "Create RequestSend status:" << send->Status(); diff --git a/paddle/fluid/operators/detail/grpc_server.h b/paddle/fluid/operators/detail/grpc_server.h index b0596d3cd1e10..014bbd0e7f4ab 100644 --- a/paddle/fluid/operators/detail/grpc_server.h +++ b/paddle/fluid/operators/detail/grpc_server.h @@ -62,6 +62,8 @@ class AsyncGRPCServer final { void SetExecutor(framework::Executor *executor) { executor_ = executor; } + int GetSelectedPort() { return selected_port_; } + const ReceivedMessage Get() { return this->var_recv_queue_.Pop(); } void Push(const std::string &msg_name) { @@ -109,6 +111,7 @@ class AsyncGRPCServer final { int prefetch_blk_id_; framework::ProgramDesc *program_; framework::Executor *executor_; + int selected_port_; }; }; // namespace detail diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index b19add24e2bd3..503ddd5d2f939 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -12,185 +12,145 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include #include +#include -#include "paddle/fluid/framework/executor.h" -#include "paddle/fluid/framework/lod_tensor.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/threadpool.h" -#include "paddle/fluid/operators/detail/grpc_server.h" +#include "paddle/fluid/operators/listen_and_serv_op.h" namespace paddle { namespace operators { -constexpr char kOptimizeBlock[] = "OptimizeBlock"; - void RunServer(std::shared_ptr service) { service->RunSyncUpdate(); VLOG(4) << "RunServer thread end"; } -static void CreateTensorFromMessageType(framework::Variable *var, - sendrecv::VarType var_type) { - if (var_type == sendrecv::VarType::LOD_TENSOR) { - var->GetMutable(); - } else if (var_type == sendrecv::VarType::SELECTED_ROWS) { - var->GetMutable(); - } else { - PADDLE_THROW( - "VariableMessage type %d is not in " - "[LoDTensor, SelectedRows]", - var_type); - } +ListenAndServOp::ListenAndServOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + +int ListenAndServOp::GetSelectedPort() { + return rpc_service_->GetSelectedPort(); } -static void ParallelExecuteBlocks(const std::vector ¶llel_blkids, - framework::Executor *executor, - framework::ProgramDesc *program, - framework::Scope *scope) { - std::vector> fs; - for (size_t idx : parallel_blkids) { - fs.push_back(framework::Async([&executor, &program, &scope, idx]() { - int run_block = idx; // thread local - try { - executor->Run(*program, scope, run_block, false, false); - } catch (std::exception &e) { - LOG(ERROR) << "run sub program error " << e.what(); - } - })); - } - for (size_t i = 0; i < fs.size(); ++i) fs[i].wait(); +void ListenAndServOp::Stop() { + rpc_service_->Push(LISTEN_TERMINATE_MESSAGE); + server_thread_->join(); } -class ListenAndServOp : public framework::OperatorBase { - public: - ListenAndServOp(const std::string &type, - const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : OperatorBase(type, inputs, outputs, attrs) { - if (!rpc_service_) { - std::string endpoint = Attr("endpoint"); - rpc_service_.reset(new detail::AsyncGRPCServer(endpoint)); - server_thread_.reset(new std::thread(RunServer, rpc_service_)); - } - } +void ListenAndServOp::RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const { + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(dev_place); + framework::Scope &recv_scope = scope.NewScope(); + LOG(INFO) << "created recv scope: " << &recv_scope; - void Stop() override { - rpc_service_->Push(LISTEN_TERMINATE_MESSAGE); - server_thread_->join(); + if (!rpc_service_) { + std::string endpoint = Attr("endpoint"); + rpc_service_.reset(new detail::AsyncGRPCServer(endpoint)); } - void RunImpl(const framework::Scope &scope, - const platform::Place &dev_place) const override { - platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); - auto &dev_ctx = *pool.Get(dev_place); - framework::Scope &recv_scope = scope.NewScope(); - - // FIXME(Yancey1989): initialize rpc server with lazy mode. - rpc_service_->SetScope(&recv_scope); - rpc_service_->SetDevCtx(&dev_ctx); - auto ins = Inputs("X"); - auto fan_in = Attr("Fanin"); - - auto *block = Attr(kOptimizeBlock); - auto *program = block->Program(); - int num_blocks = program->Size(); - PADDLE_ENFORCE_GE(num_blocks, 2, - "server program should have at least 2 blocks"); - - framework::Executor executor(dev_place); - - // TODO(qiao) set proper fields for table lookup and update - rpc_service_->SetExecutor(&executor); - rpc_service_->SetPrefetchBlkdId(0); - rpc_service_->SetProgram(program); - - // TODO(typhoonzero): change this to a while_op for every cluster-batch. - bool exit_flag = false; - // Record received sparse variables, so that - // we could reset those after execute optimize program - std::vector sparse_vars; - while (!exit_flag) { - // Get from multiple trainers, we don't care about the order in which - // the gradients arrives, just add suffix 0~n and merge the gradient. - rpc_service_->SetCond(0); - size_t recv_var_cnt = 0; - int batch_barrier = 0; - while (batch_barrier != fan_in) { - const detail::ReceivedMessage v = rpc_service_->Get(); - auto recv_var_name = v.first; - if (recv_var_name == LISTEN_TERMINATE_MESSAGE) { - LOG(INFO) << "received terminate message and exit"; - exit_flag = true; - break; - } else if (recv_var_name == BATCH_BARRIER_MESSAGE) { - VLOG(3) << "recv batch barrier message"; - batch_barrier++; - continue; - } else { - VLOG(3) << "received grad: " << recv_var_name; - recv_var_cnt++; - auto var = v.second->GetVar(); - if (var == nullptr) { - LOG(ERROR) << "Can not find server side var: " << recv_var_name; - PADDLE_THROW("Can not find server side var"); - } - if (var->IsType()) { - sparse_vars.push_back(var); - } - } - } - if (exit_flag) { - rpc_service_->SetCond(1); - rpc_service_->ShutDown(); + auto ins = Inputs("X"); + auto fan_in = Attr("Fanin"); + auto *block = Attr(kOptimizeBlock); + auto *program = block->Program(); + size_t num_blocks = program->Size(); + PADDLE_ENFORCE_GE(num_blocks, 2, + "server program should have at least 2 blocks"); + + framework::Executor executor(dev_place); + + // FIXME(Yancey1989): initialize rpc server with lazy mode. + rpc_service_->SetScope(&recv_scope); + rpc_service_->SetDevCtx(&dev_ctx); + // TODO(qiao) set proper fields for table lookup and update + rpc_service_->SetExecutor(&executor); + rpc_service_->SetPrefetchBlkdId(0); + rpc_service_->SetProgram(program); + // start the server listening after all member initialized. + server_thread_.reset(new std::thread(RunServer, rpc_service_)); + // FIXME(typhoonzero): do we need to wait until the server port is ready? + sleep(5); + + // TODO(typhoonzero): change this to a while_op for every cluster-batch. + bool exit_flag = false; + // Record received sparse variables, so that + // we could reset those after execute optimize program + std::vector sparse_vars; + while (!exit_flag) { + // Get from multiple trainers, we don't care about the order in which + // the gradients arrives, just add suffix 0~n and merge the gradient. + rpc_service_->SetCond(0); + size_t recv_var_cnt = 0; + int batch_barrier = 0; + while (batch_barrier != fan_in) { + const detail::ReceivedMessage v = rpc_service_->Get(); + auto recv_var_name = v.first; + if (recv_var_name == LISTEN_TERMINATE_MESSAGE) { + LOG(INFO) << "received terminate message and exit"; + exit_flag = true; break; - } - - // NOTE: if is_gpu_place, CUDA kernels are laugched by multiple threads - // and this will still work. - - // The optimize blocks which have the same parent ID would run parallel - // TODO(Yancey1989): need to use ParallelExecutor for future - size_t last_parent_blkid = program->Block(1).Parent(); - std::vector parallel_blkids; - parallel_blkids.push_back(1); - double ts = detail::GetTimestamp(); - for (size_t blkid = 2; blkid < num_blocks; ++blkid) { - if (program->Block(blkid).Parent() != last_parent_blkid) { - for (size_t idx : parallel_blkids) VLOG(3) << idx; - ParallelExecuteBlocks(parallel_blkids, &executor, program, - &recv_scope); - parallel_blkids.clear(); - last_parent_blkid = program->Block(blkid).Parent(); + } else if (recv_var_name == BATCH_BARRIER_MESSAGE) { + VLOG(3) << "recv batch barrier message"; + batch_barrier++; + continue; + } else { + VLOG(3) << "received grad: " << recv_var_name; + recv_var_cnt++; + auto var = v.second->GetVar(); + if (var == nullptr) { + LOG(ERROR) << "Can not find server side var: " << recv_var_name; + PADDLE_THROW("Can not find server side var"); + } + if (var->IsType()) { + sparse_vars.push_back(var); } - parallel_blkids.push_back(blkid); - } - ParallelExecuteBlocks(parallel_blkids, &executor, program, &recv_scope); - - VLOG(3) << "run all blocks spent " << detail::GetTimestamp() - ts - << "(ms)"; - - // Reset the received sparse variables, the sum operator would not - // sum the input sparse variables which rows is empty at the next - // mini-batch. - // TODO(Yancey1989): move the reset action into an operator, we couldn't - // have any hide logic in the operator. - for (auto &var : sparse_vars) { - var->GetMutable()->mutable_rows()->clear(); } + } + if (exit_flag) { rpc_service_->SetCond(1); - // FIXME(typhoonzero): use another condition to sync wait clients get. - rpc_service_->WaitClientGet(fan_in); - sparse_vars.clear(); - } // while(true) - } + rpc_service_->ShutDown(); + break; + } - protected: - std::shared_ptr rpc_service_; - std::shared_ptr server_thread_; -}; + // NOTE: if is_gpu_place, CUDA kernels are laugched by multiple threads + // and this will still work. + + // The optimize blocks which have the same parent ID would run parallel + // TODO(Yancey1989): need to use ParallelExecutor for future + int32_t last_parent_blkid = program->Block(1).Parent(); + std::vector parallel_blkids; + parallel_blkids.push_back(1); + double ts = detail::GetTimestamp(); + for (size_t blkid = 2; blkid < num_blocks; ++blkid) { + if (program->Block(blkid).Parent() != last_parent_blkid) { + for (size_t idx : parallel_blkids) VLOG(3) << idx; + ParallelExecuteBlocks(parallel_blkids, &executor, program, &recv_scope); + parallel_blkids.clear(); + last_parent_blkid = program->Block(blkid).Parent(); + } + parallel_blkids.push_back(blkid); + } + ParallelExecuteBlocks(parallel_blkids, &executor, program, &recv_scope); + + VLOG(3) << "run all blocks spent " << detail::GetTimestamp() - ts << "(ms)"; + + // Reset the received sparse variables, the sum operator would not + // sum the input sparse variables which rows is empty at the next + // mini-batch. + // TODO(Yancey1989): move the reset action into an operator, we couldn't + // have any hide logic in the operator. + for (auto &var : sparse_vars) { + var->GetMutable()->mutable_rows()->clear(); + } + rpc_service_->SetCond(1); + // FIXME(typhoonzero): use another condition to sync wait clients get. + rpc_service_->WaitClientGet(fan_in); + sparse_vars.clear(); + } // while(true) +} class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker { public: diff --git a/paddle/fluid/operators/listen_and_serv_op.h b/paddle/fluid/operators/listen_and_serv_op.h new file mode 100644 index 0000000000000..4ebb8c3840e3c --- /dev/null +++ b/paddle/fluid/operators/listen_and_serv_op.h @@ -0,0 +1,85 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include + +#include "paddle/fluid/framework/executor.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/threadpool.h" +#include "paddle/fluid/operators/detail/grpc_server.h" + +namespace paddle { +namespace operators { + +constexpr char kOptimizeBlock[] = "OptimizeBlock"; + +void RunServer(std::shared_ptr service); + +static void CreateTensorFromMessageType(framework::Variable *var, + sendrecv::VarType var_type) { + if (var_type == sendrecv::VarType::LOD_TENSOR) { + var->GetMutable(); + } else if (var_type == sendrecv::VarType::SELECTED_ROWS) { + var->GetMutable(); + } else { + PADDLE_THROW( + "VariableMessage type %d is not in " + "[LoDTensor, SelectedRows]", + var_type); + } +} + +static void ParallelExecuteBlocks(const std::vector ¶llel_blkids, + framework::Executor *executor, + framework::ProgramDesc *program, + framework::Scope *scope) { + std::vector> fs; + for (size_t idx : parallel_blkids) { + fs.push_back(framework::Async([&executor, &program, &scope, idx]() { + int run_block = idx; // thread local + try { + executor->Run(*program, scope, run_block, false, false); + } catch (std::exception &e) { + LOG(ERROR) << "run sub program error " << e.what(); + } + })); + } + for (size_t i = 0; i < fs.size(); ++i) fs[i].wait(); +} + +class ListenAndServOp : public framework::OperatorBase { + public: + ListenAndServOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs); + + int GetSelectedPort(); + + void Stop() override; + + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override; + + protected: + mutable std::shared_ptr rpc_service_; + mutable std::shared_ptr server_thread_; +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/send_recv_op_test.cc b/paddle/fluid/operators/send_recv_op_test.cc index 04392b3e05fa2..542bc3fde2a36 100644 --- a/paddle/fluid/operators/send_recv_op_test.cc +++ b/paddle/fluid/operators/send_recv_op_test.cc @@ -20,6 +20,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/operators/listen_and_serv_op.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/selected_rows_functor.h" #include "paddle/fluid/string/printf.h" @@ -34,6 +35,7 @@ namespace m = paddle::operators::math; // global for simplicity. std::unique_ptr listen_and_serv_op; +int selected_port; void InitTensorsInScope(f::Scope &scope, p::CPUPlace &place) { p::CPUDeviceContext ctx(place); @@ -128,14 +130,16 @@ void StartServerNet(bool is_sparse) { AddOp("sum", {{"X", {"x0", "x1"}}}, {{"Out", {"Out"}}}, {}, optimize_block); f::AttributeMap attrs; - attrs.insert({"endpoint", std::string("127.0.0.1:6174")}); + attrs.insert({"endpoint", std::string("127.0.0.1:0")}); attrs.insert({"Fanin", 1}); attrs.insert({"ParamList", std::vector({"Out"})}); attrs.insert({"GradList", std::vector({"x1"})}); attrs.insert({"OptimizeBlock", optimize_block}); listen_and_serv_op = f::OpRegistry::CreateOp("listen_and_serv", {{"X", {"x1"}}}, {}, attrs); + LOG(INFO) << "selected port before run " << selected_port; listen_and_serv_op->Run(scope, place); + LOG(INFO) << "server exit"; } TEST(SendRecvOp, CPUDense) { @@ -149,12 +153,19 @@ TEST(SendRecvOp, CPUDense) { scope.Var("RPC_CLIENT_VAR"); f::AttributeMap attrs; - attrs.insert({"endpoints", std::vector({"127.0.0.1:6174"})}); - attrs.insert({"epmap", std::vector({"127.0.0.1:6174"})}); + selected_port = static_cast( + listen_and_serv_op.get()) + ->GetSelectedPort(); + LOG(INFO) << "selected port " << selected_port; + std::string endpoint = paddle::string::Sprintf("127.0.0.1:%d", selected_port); + attrs.insert({"endpoints", std::vector({endpoint})}); + attrs.insert({"epmap", std::vector({endpoint})}); auto send_op = f::OpRegistry::CreateOp( "send", {{"X", {"x1"}}}, {{"Out", {"Out"}}, {"RPCClient", {"RPC_CLIENT_VAR"}}}, attrs); + LOG(INFO) << "before run " << endpoint; send_op->Run(scope, place); + LOG(INFO) << "end run"; auto in_var = scope.Var("x1"); auto tensor = in_var->GetMutable(); @@ -167,6 +178,7 @@ TEST(SendRecvOp, CPUDense) { for (int64_t i = 0; i < target->numel(); ++i) { EXPECT_EQ(expected[i] * 2, actual[i]); } + LOG(INFO) << "before stop"; listen_and_serv_op->Stop(); server_thread.join(); listen_and_serv_op.reset(nullptr); @@ -182,8 +194,13 @@ TEST(SendRecvOp, CPUSparse) { InitSelectedRowsInScope(scope, place); scope.Var("RPC_CLIENT_VAR"); f::AttributeMap attrs; - attrs.insert({"endpoints", std::vector({"127.0.0.1:6174"})}); - attrs.insert({"epmap", std::vector({"127.0.0.1:6174"})}); + selected_port = static_cast( + listen_and_serv_op.get()) + ->GetSelectedPort(); + LOG(INFO) << "selected port " << selected_port; + std::string endpoint = paddle::string::Sprintf("127.0.0.1:%d", selected_port); + attrs.insert({"endpoints", std::vector({endpoint})}); + attrs.insert({"epmap", std::vector({endpoint})}); auto send_op = f::OpRegistry::CreateOp( "send", {{"X", {"x1"}}}, {{"Out", {"Out"}}, {"RPCClient", {"RPC_CLIENT_VAR"}}}, attrs); From 00f8e63b8d4577c77ebafb14909e9ca20e57bb3b Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Tue, 3 Apr 2018 17:35:02 +0800 Subject: [PATCH 2/2] update --- paddle/fluid/operators/CMakeLists.txt | 1 + paddle/fluid/operators/detail/grpc_server.cc | 3 -- paddle/fluid/operators/listen_and_serv_op.cc | 33 +++++++++++++++++++- paddle/fluid/operators/listen_and_serv_op.h | 32 ------------------- 4 files changed, 33 insertions(+), 36 deletions(-) diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 9ed79453b962b..952ac8b1dcf92 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -193,6 +193,7 @@ if(WITH_DISTRIBUTE) set_source_files_properties(send_vars_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) op_library(send_barrier_op DEPS ${DISTRIBUTE_DEPS}) set_source_files_properties(send_barrier_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) + set_source_files_properties(send_recv_op_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) cc_test(test_send_recv SRCS send_recv_op_test.cc DEPS prefetch_op send_op listen_and_serv_op sum_op executor) else() set(DEPS_OPS ${DEPS_OPS} send_op prefetch_op recv_op listen_and_serv_op send_vars_op send_barrier_op) diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc index b81d37415848f..1515004d97956 100644 --- a/paddle/fluid/operators/detail/grpc_server.cc +++ b/paddle/fluid/operators/detail/grpc_server.cc @@ -244,9 +244,6 @@ void AsyncGRPCServer::TryToRegisterNewSendOne() { VLOG(3) << "shutdown, do not TryToRegisterNewSendOne"; return; } - while (scope_ == nullptr) { - sleep(0.01); - } RequestSend* send = new RequestSend(&service_, cq_send_.get(), scope_, &var_recv_queue_, dev_ctx_); VLOG(4) << "Create RequestSend status:" << send->Status(); diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index 503ddd5d2f939..611457e6d6c72 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -25,6 +25,38 @@ void RunServer(std::shared_ptr service) { VLOG(4) << "RunServer thread end"; } +static void CreateTensorFromMessageType(framework::Variable *var, + sendrecv::VarType var_type) { + if (var_type == sendrecv::VarType::LOD_TENSOR) { + var->GetMutable(); + } else if (var_type == sendrecv::VarType::SELECTED_ROWS) { + var->GetMutable(); + } else { + PADDLE_THROW( + "VariableMessage type %d is not in " + "[LoDTensor, SelectedRows]", + var_type); + } +} + +static void ParallelExecuteBlocks(const std::vector ¶llel_blkids, + framework::Executor *executor, + framework::ProgramDesc *program, + framework::Scope *scope) { + std::vector> fs; + for (size_t idx : parallel_blkids) { + fs.push_back(framework::Async([&executor, &program, &scope, idx]() { + int run_block = idx; // thread local + try { + executor->Run(*program, scope, run_block, false, false); + } catch (std::exception &e) { + LOG(ERROR) << "run sub program error " << e.what(); + } + })); + } + for (size_t i = 0; i < fs.size(); ++i) fs[i].wait(); +} + ListenAndServOp::ListenAndServOp(const std::string &type, const framework::VariableNameMap &inputs, const framework::VariableNameMap &outputs, @@ -62,7 +94,6 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, framework::Executor executor(dev_place); - // FIXME(Yancey1989): initialize rpc server with lazy mode. rpc_service_->SetScope(&recv_scope); rpc_service_->SetDevCtx(&dev_ctx); // TODO(qiao) set proper fields for table lookup and update diff --git a/paddle/fluid/operators/listen_and_serv_op.h b/paddle/fluid/operators/listen_and_serv_op.h index 4ebb8c3840e3c..0da87afc961e8 100644 --- a/paddle/fluid/operators/listen_and_serv_op.h +++ b/paddle/fluid/operators/listen_and_serv_op.h @@ -30,38 +30,6 @@ constexpr char kOptimizeBlock[] = "OptimizeBlock"; void RunServer(std::shared_ptr service); -static void CreateTensorFromMessageType(framework::Variable *var, - sendrecv::VarType var_type) { - if (var_type == sendrecv::VarType::LOD_TENSOR) { - var->GetMutable(); - } else if (var_type == sendrecv::VarType::SELECTED_ROWS) { - var->GetMutable(); - } else { - PADDLE_THROW( - "VariableMessage type %d is not in " - "[LoDTensor, SelectedRows]", - var_type); - } -} - -static void ParallelExecuteBlocks(const std::vector ¶llel_blkids, - framework::Executor *executor, - framework::ProgramDesc *program, - framework::Scope *scope) { - std::vector> fs; - for (size_t idx : parallel_blkids) { - fs.push_back(framework::Async([&executor, &program, &scope, idx]() { - int run_block = idx; // thread local - try { - executor->Run(*program, scope, run_block, false, false); - } catch (std::exception &e) { - LOG(ERROR) << "run sub program error " << e.what(); - } - })); - } - for (size_t i = 0; i < fs.size(); ++i) fs[i].wait(); -} - class ListenAndServOp : public framework::OperatorBase { public: ListenAndServOp(const std::string &type,