Skip to content

Commit

Permalink
Merge pull request #9593 from Yancey1989/prefech_prog_on_server
Browse files Browse the repository at this point in the history
run prefetch prog on server
  • Loading branch information
Yancey committed Apr 8, 2018
2 parents 7d39725 + 974b253 commit be85385
Show file tree
Hide file tree
Showing 12 changed files with 147 additions and 28 deletions.
5 changes: 3 additions & 2 deletions paddle/fluid/framework/scope.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ limitations under the License. */
#include "paddle/fluid/framework/scope.h"

#include <memory> // for unique_ptr
#include <mutex> // for call_once
#include <set>
#include "glog/logging.h"
#include "paddle/fluid/framework/threadpool.h"
Expand All @@ -39,6 +38,7 @@ Scope::~Scope() {
}

Scope& Scope::NewScope() const {
std::unique_lock<std::mutex> lock(mutex_);
kids_.push_back(new Scope(this));
return *kids_.back();
}
Expand Down Expand Up @@ -92,6 +92,7 @@ std::vector<std::string> Scope::LocalVarNames() const {
}

void Scope::DeleteScope(Scope* scope) {
std::unique_lock<std::mutex> lock(mutex_);
auto it = std::find(this->kids_.begin(), this->kids_.end(), scope);
PADDLE_ENFORCE(it != this->kids_.end(), "Cannot find %p as kid scope", scope);
this->kids_.erase(it);
Expand All @@ -103,7 +104,7 @@ void Scope::DeleteScope(Scope* scope) {
}
}

void Scope::EraseVars(std::vector<std::string>& var_names) {
void Scope::EraseVars(const std::vector<std::string>& var_names) {
std::set<std::string> var_set(var_names.begin(), var_names.end());
for (auto it = vars_.begin(); it != vars_.end();) {
if (var_set.find(it->first) != var_set.end()) {
Expand Down
6 changes: 5 additions & 1 deletion paddle/fluid/framework/scope.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License. */
#pragma once

#include <list>
#include <mutex> // NOLINT
#include <string>
#include <unordered_map>
#include <vector>
Expand Down Expand Up @@ -51,7 +52,7 @@ class Scope {
/// Create a variable with a scope-unique name.
Variable* Var(std::string* name = nullptr);

void EraseVars(std::vector<std::string>& var_names);
void EraseVars(const std::vector<std::string>& var_names);

/// Find a variable in the scope or any of its ancestors. Returns
/// nullptr if cannot find.
Expand Down Expand Up @@ -88,6 +89,9 @@ class Scope {
Scope const* parent_{nullptr};

DISABLE_COPY_AND_ASSIGN(Scope);

private:
mutable std::mutex mutex_;
};
} // namespace framework
} // namespace paddle
2 changes: 1 addition & 1 deletion paddle/fluid/operators/detail/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@ if(WITH_DISTRIBUTE)
set_source_files_properties(serde_test.cc grpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(serde_test SRCS serde_test.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr
cares zlib protobuf sendrecvop_grpc)
cc_test(grpc_server_test SRCS grpc_server_test.cc DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf)
cc_test(grpc_server_test SRCS grpc_server_test.cc DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_table_op)
endif()
2 changes: 1 addition & 1 deletion paddle/fluid/operators/detail/grpc_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep,
auto* var = p_scope->FindVar(in_var_name_val);

::grpc::ByteBuffer req;
SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req);
SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req, out_var_name_val);

// var handle
VarHandle var_h;
Expand Down
27 changes: 18 additions & 9 deletions paddle/fluid/operators/detail/grpc_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,39 +138,48 @@ class RequestPrefetch final : public RequestBase {
framework::Scope* scope,
const platform::DeviceContext* dev_ctx,
framework::Executor* executor,
framework::ProgramDesc* program, int blkid)
framework::ProgramDesc* program,
framework::ExecutorPrepareContext* prefetch_ctx)
: RequestBase(service, cq, dev_ctx),
responder_(&ctx_),
scope_(scope),
executor_(executor),
program_(program),
blkid_(blkid) {
prefetch_ctx_(prefetch_ctx) {
request_.reset(new VariableResponse(scope, dev_ctx_));
int method_id = static_cast<int>(detail::GrpcMethod::kPrefetchVariable);
service_->RequestAsyncUnary(method_id, &ctx_, &request_, &responder_, cq_,
cq_, this);
service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_,
cq_, cq_, this);
}

virtual ~RequestPrefetch() {}

virtual std::string GetReqName() { return request_.varname(); }
virtual std::string GetReqName() { return request_->Varname(); }

virtual void Process() {
// prefetch process...
::grpc::ByteBuffer reply;
// TODO(Yancey1989): execute the Block which containers prefetch ops

VLOG(3) << "RequestPrefetch Process in";
std::string var_name = request_->OutVarname();
auto var_desc = program_->Block(0).FindVar(var_name);
framework::Scope* local_scope = &scope_->NewScope();
auto* var = local_scope->FindVar(var_name);
InitializeVariable(var, var_desc->GetType());
executor_->RunPreparedContext(prefetch_ctx_, scope_, false, false);

SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply);

responder_.Finish(reply, ::grpc::Status::OK, this);
status_ = FINISH;
}

protected:
sendrecv::VariableMessage request_;
std::shared_ptr<VariableResponse> request_;
ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
framework::Scope* scope_;
framework::Executor* executor_;
framework::ProgramDesc* program_;
framework::ExecutorPrepareContext* prefetch_ctx_;
int blkid_;
};

Expand Down Expand Up @@ -268,7 +277,7 @@ void AsyncGRPCServer::TryToRegisterNewPrefetchOne() {
}
RequestPrefetch* prefetch =
new RequestPrefetch(&service_, cq_prefetch_.get(), scope_, dev_ctx_,
executor_, program_, prefetch_blk_id_);
executor_, program_, prefetch_ctx_);

VLOG(4) << "Create RequestPrefetch status:" << prefetch->Status();
}
Expand Down
5 changes: 5 additions & 0 deletions paddle/fluid/operators/detail/grpc_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ class AsyncGRPCServer final {

void SetExecutor(framework::Executor *executor) { executor_ = executor; }

void SetPrefetchPreparedCtx(framework::ExecutorPrepareContext *prepared) {
prefetch_ctx_ = prepared;
}

int GetSelectedPort() { return selected_port_; }

const ReceivedMessage Get() { return this->var_recv_queue_.Pop(); }
Expand Down Expand Up @@ -111,6 +115,7 @@ class AsyncGRPCServer final {
std::unique_ptr<std::thread> t_prefetch_;

int prefetch_blk_id_;
framework::ExecutorPrepareContext *prefetch_ctx_;
framework::ProgramDesc *program_;
framework::Executor *executor_;
int selected_port_;
Expand Down
100 changes: 89 additions & 11 deletions paddle/fluid/operators/detail/grpc_server_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,43 +20,121 @@ limitations under the License. */
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/detail/grpc_server.h"

#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"

namespace framework = paddle::framework;
namespace platform = paddle::platform;
namespace detail = paddle::operators::detail;

USE_OP(lookup_table);

std::unique_ptr<detail::AsyncGRPCServer> rpc_service_;

framework::BlockDesc* AppendPrefetchBlcok(framework::ProgramDesc* program) {
auto root_block = program->MutableBlock(0);
auto* block = program->AppendBlock(*root_block);

framework::VariableNameMap input({{"W", {"w"}}, {"Ids", {"ids"}}});
framework::VariableNameMap output({{"Output", {"out"}}});
auto op = block->AppendOp();
op->SetType("lookup_table");
op->SetInput("W", {"w"});
op->SetInput("Ids", {"ids"});
op->SetOutput("Out", {"out"});

auto& out = *root_block->Var("out");
out.SetType(framework::proto::VarType::SELECTED_ROWS);
out.SetShape({10, 10});

return block;
}

void CreateVarsOnScope(framework::Scope* scope, platform::CPUPlace* place) {
auto w_var = scope->Var("w");
w_var->GetMutable<framework::SelectedRows>();

auto out_var = scope->Var("out");
out_var->GetMutable<framework::SelectedRows>();

auto ids_var = scope->Var("ids");
ids_var->GetMutable<framework::SelectedRows>();
}

void InitTensorsOnClient(framework::Scope* scope, platform::CPUPlace* place,
int64_t rows_numel) {
CreateVarsOnScope(scope, place);
auto ids_var = scope->Var("ids")->GetMutable<framework::SelectedRows>();
auto rows = ids_var->mutable_rows();
for (int64_t i = 0; i < rows_numel; ++i) rows->push_back(i * 2);
ids_var->mutable_value()->Resize({rows_numel, 1});
ids_var->mutable_value()->mutable_data<float>(*place);
}

void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place,
int64_t rows_numel) {
CreateVarsOnScope(scope, place);
auto w = scope->Var("w")->GetMutable<framework::SelectedRows>();
auto rows = w->mutable_rows();
for (int64_t i = 0; i < rows_numel; ++i) rows->push_back(i);
auto w_value = w->mutable_value();
w_value->Resize({rows_numel, 10});

auto ptr = w_value->mutable_data<float>(*place);

for (int64_t i = 0; i < w_value->numel(); ++i) {
ptr[i] = static_cast<float>(i / 10);
}
}

void StartServer(const std::string& endpoint) {
rpc_service_.reset(new detail::AsyncGRPCServer(endpoint));
framework::ProgramDesc program;
framework::Scope scope;
platform::CPUPlace place;
framework::Executor exe(place);
platform::CPUDeviceContext ctx(place);
auto* block = AppendPrefetchBlcok(&program);
auto prepared = exe.Prepare(program, block->ID());
InitTensorsOnServer(&scope, &place, 10);

rpc_service_->SetProgram(&program);
rpc_service_->SetPrefetchPreparedCtx(prepared.get());
rpc_service_->SetDevCtx(&ctx);
rpc_service_->SetScope(&scope);
rpc_service_->SetExecutor(&exe);

rpc_service_->RunSyncUpdate();
}

TEST(PREFETCH, CPU) {
// start up a server instance backend
// TODO(Yancey1989): Need to start a server with optimize blocks and
// prefetch blocks.
std::thread server_thread(StartServer, "127.0.0.1:8889");
sleep(2);
framework::Scope scope;
platform::CPUPlace place;
platform::CPUDeviceContext ctx(place);
// create var on local scope
std::string in_var_name("in");
int64_t rows_numel = 5;
InitTensorsOnClient(&scope, &place, rows_numel);
std::string in_var_name("ids");
std::string out_var_name("out");
auto* in_var = scope.Var(in_var_name);
auto* in_tensor = in_var->GetMutable<framework::LoDTensor>();
in_tensor->Resize({10, 10});
VLOG(3) << "before mutable_data";
in_tensor->mutable_data<int>(place);

scope.Var(out_var_name);

VLOG(3) << "before fetch";
detail::RPCClient client;
client.AsyncPrefetchVariable("127.0.0.1:8889", ctx, scope, in_var_name,
out_var_name);
client.Wait();

auto var = scope.Var(out_var_name);
auto value = var->GetMutable<framework::SelectedRows>()->value();
auto ptr = value.mutable_data<float>(place);

rpc_service_->ShutDown();
server_thread.join();
rpc_service_.reset(nullptr);

for (int64_t i = 0; i < rows_numel; ++i) {
EXPECT_EQ(ptr[0 + i * value.dims()[1]], static_cast<float>(i * 2));
}
}
4 changes: 3 additions & 1 deletion paddle/fluid/operators/detail/send_recv.proto
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ service SendRecvService {
rpc SendVariable(VariableMessage) returns (VoidMessage) {}
// Argument VariableMessage for GetVariable should only contain varname.
rpc GetVariable(VariableMessage) returns (VariableMessage) {}
// Prefetch variable by Ids
// pre-fetch variable by given variable name and Ids
rpc PrefetchVariable(VariableMessage) returns (VariableMessage) {}
}

Expand Down Expand Up @@ -67,6 +67,8 @@ message VariableMessage {
bytes serialized = 8;
// selected_rows data
bytes rows = 9;
// Look up table block execution output variable name.
string out_varname = 10;
}

message VoidMessage {}
6 changes: 5 additions & 1 deletion paddle/fluid/operators/detail/sendrecvop_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ namespace detail {

void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
const platform::DeviceContext& ctx,
::grpc::ByteBuffer* msg) {
::grpc::ByteBuffer* msg,
const std::string& out_name) {
using VarMsg = sendrecv::VariableMessage;
sendrecv::VariableMessage request;
std::string header;
Expand All @@ -52,6 +53,9 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
e.WriteUint64(VarMsg::kTypeFieldNumber, 1);
}

if (!out_name.empty()) {
e.WriteString(VarMsg::kOutVarnameFieldNumber, out_name);
}
switch (framework::ToVarType(var->Type())) {
case framework::proto::VarType_Type_LOD_TENSOR: {
auto tensor = var->Get<framework::LoDTensor>();
Expand Down
3 changes: 2 additions & 1 deletion paddle/fluid/operators/detail/sendrecvop_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ typedef void (*DestroyCallback)(void*);

void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
const platform::DeviceContext& ctx,
::grpc::ByteBuffer* msg);
::grpc::ByteBuffer* msg,
const std::string& out_varname = std::string());

void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
const platform::DeviceContext& ctx,
Expand Down
14 changes: 14 additions & 0 deletions paddle/fluid/operators/detail/variable_response.cc
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,20 @@ int VariableResponse::Parse(Source* source) {
}
break;
}
case sendrecv::VariableMessage::kOutVarnameFieldNumber: {
uint32_t length;
if ((wt != WIRETYPE_LENGTH_DELIMITED) || !input.ReadVarint32(&length)) {
return tag;
}

std::string temp;
if (!input.ReadString(&temp, length)) {
return tag;
}

meta_.set_out_varname(temp);
break;
}

default: {
// Unknown tag, return unknown error.
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/operators/detail/variable_response.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class VariableResponse {
int Parse(const ::grpc::ByteBuffer& byte_buffer);

inline std::string Varname() { return meta_.varname(); }
inline std::string OutVarname() { return meta_.out_varname(); }

// should call parse first.
framework::Variable* GetVar() { return scope_->FindVar(meta_.varname()); }
Expand Down

0 comments on commit be85385

Please sign in to comment.