Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

run prefetch prog on server #9593

Merged
merged 7 commits into from
Apr 8, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions paddle/fluid/framework/scope.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ limitations under the License. */
#include "paddle/fluid/framework/scope.h"

#include <memory> // for unique_ptr
#include <mutex> // for call_once
#include <set>
#include "glog/logging.h"
#include "paddle/fluid/framework/threadpool.h"
Expand All @@ -39,6 +38,7 @@ Scope::~Scope() {
}

Scope& Scope::NewScope() const {
std::unique_lock<std::mutex> lock(mutex_);
kids_.push_back(new Scope(this));
return *kids_.back();
}
Expand Down Expand Up @@ -92,6 +92,7 @@ std::vector<std::string> Scope::LocalVarNames() const {
}

void Scope::DeleteScope(Scope* scope) {
std::unique_lock<std::mutex> lock(mutex_);
auto it = std::find(this->kids_.begin(), this->kids_.end(), scope);
PADDLE_ENFORCE(it != this->kids_.end(), "Cannot find %p as kid scope", scope);
this->kids_.erase(it);
Expand All @@ -103,7 +104,7 @@ void Scope::DeleteScope(Scope* scope) {
}
}

void Scope::EraseVars(std::vector<std::string>& var_names) {
void Scope::EraseVars(const std::vector<std::string>& var_names) {
std::set<std::string> var_set(var_names.begin(), var_names.end());
for (auto it = vars_.begin(); it != vars_.end();) {
if (var_set.find(it->first) != var_set.end()) {
Expand Down
6 changes: 5 additions & 1 deletion paddle/fluid/framework/scope.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License. */
#pragma once

#include <list>
#include <mutex> // NOLINT
#include <string>
#include <unordered_map>
#include <vector>
Expand Down Expand Up @@ -51,7 +52,7 @@ class Scope {
/// Create a variable with a scope-unique name.
Variable* Var(std::string* name = nullptr);

void EraseVars(std::vector<std::string>& var_names);
void EraseVars(const std::vector<std::string>& var_names);

/// Find a variable in the scope or any of its ancestors. Returns
/// nullptr if cannot find.
Expand Down Expand Up @@ -88,6 +89,9 @@ class Scope {
Scope const* parent_{nullptr};

DISABLE_COPY_AND_ASSIGN(Scope);

private:
mutable std::mutex mutex_;
};
} // namespace framework
} // namespace paddle
2 changes: 1 addition & 1 deletion paddle/fluid/operators/detail/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@ if(WITH_DISTRIBUTE)
set_source_files_properties(serde_test.cc grpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(serde_test SRCS serde_test.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr
cares zlib protobuf sendrecvop_grpc)
cc_test(grpc_server_test SRCS grpc_server_test.cc DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf)
cc_test(grpc_server_test SRCS grpc_server_test.cc DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_table_op)
endif()
2 changes: 1 addition & 1 deletion paddle/fluid/operators/detail/grpc_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep,
auto* var = p_scope->FindVar(in_var_name_val);

::grpc::ByteBuffer req;
SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req);
SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req, out_var_name_val);

// var handle
VarHandle var_h;
Expand Down
27 changes: 18 additions & 9 deletions paddle/fluid/operators/detail/grpc_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,39 +138,48 @@ class RequestPrefetch final : public RequestBase {
framework::Scope* scope,
const platform::DeviceContext* dev_ctx,
framework::Executor* executor,
framework::ProgramDesc* program, int blkid)
framework::ProgramDesc* program,
framework::ExecutorPrepareContext* prefetch_ctx)
: RequestBase(service, cq, dev_ctx),
responder_(&ctx_),
scope_(scope),
executor_(executor),
program_(program),
blkid_(blkid) {
prefetch_ctx_(prefetch_ctx) {
request_.reset(new VariableResponse(scope, dev_ctx_));
int method_id = static_cast<int>(detail::GrpcMethod::kPrefetchVariable);
service_->RequestAsyncUnary(method_id, &ctx_, &request_, &responder_, cq_,
cq_, this);
service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_,
cq_, cq_, this);
}

virtual ~RequestPrefetch() {}

virtual std::string GetReqName() { return request_.varname(); }
virtual std::string GetReqName() { return request_->Varname(); }

virtual void Process() {
// prefetch process...
::grpc::ByteBuffer reply;
// TODO(Yancey1989): execute the Block which containers prefetch ops

VLOG(3) << "RequestPrefetch Process in";
std::string var_name = request_->OutVarname();
auto var_desc = program_->Block(0).FindVar(var_name);
framework::Scope* local_scope = &scope_->NewScope();
auto* var = local_scope->FindVar(var_name);
InitializeVariable(var, var_desc->GetType());
executor_->RunPreparedContext(prefetch_ctx_, scope_, false, false);

SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply);

responder_.Finish(reply, ::grpc::Status::OK, this);
status_ = FINISH;
}

protected:
sendrecv::VariableMessage request_;
std::shared_ptr<VariableResponse> request_;
ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
framework::Scope* scope_;
framework::Executor* executor_;
framework::ProgramDesc* program_;
framework::ExecutorPrepareContext* prefetch_ctx_;
int blkid_;
};

Expand Down Expand Up @@ -268,7 +277,7 @@ void AsyncGRPCServer::TryToRegisterNewPrefetchOne() {
}
RequestPrefetch* prefetch =
new RequestPrefetch(&service_, cq_prefetch_.get(), scope_, dev_ctx_,
executor_, program_, prefetch_blk_id_);
executor_, program_, prefetch_ctx_);

VLOG(4) << "Create RequestPrefetch status:" << prefetch->Status();
}
Expand Down
5 changes: 5 additions & 0 deletions paddle/fluid/operators/detail/grpc_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ class AsyncGRPCServer final {

void SetExecutor(framework::Executor *executor) { executor_ = executor; }

void SetPrefetchPreparedCtx(framework::ExecutorPrepareContext *prepared) {
prefetch_ctx_ = prepared;
}

int GetSelectedPort() { return selected_port_; }

const ReceivedMessage Get() { return this->var_recv_queue_.Pop(); }
Expand Down Expand Up @@ -111,6 +115,7 @@ class AsyncGRPCServer final {
std::unique_ptr<std::thread> t_prefetch_;

int prefetch_blk_id_;
framework::ExecutorPrepareContext *prefetch_ctx_;
framework::ProgramDesc *program_;
framework::Executor *executor_;
int selected_port_;
Expand Down
100 changes: 89 additions & 11 deletions paddle/fluid/operators/detail/grpc_server_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,43 +20,121 @@ limitations under the License. */
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/detail/grpc_server.h"

#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"

namespace framework = paddle::framework;
namespace platform = paddle::platform;
namespace detail = paddle::operators::detail;

USE_OP(lookup_table);

std::unique_ptr<detail::AsyncGRPCServer> rpc_service_;

framework::BlockDesc* AppendPrefetchBlcok(framework::ProgramDesc* program) {
auto root_block = program->MutableBlock(0);
auto* block = program->AppendBlock(*root_block);

framework::VariableNameMap input({{"W", {"w"}}, {"Ids", {"ids"}}});
framework::VariableNameMap output({{"Output", {"out"}}});
auto op = block->AppendOp();
op->SetType("lookup_table");
op->SetInput("W", {"w"});
op->SetInput("Ids", {"ids"});
op->SetOutput("Out", {"out"});

auto& out = *root_block->Var("out");
out.SetType(framework::proto::VarType::SELECTED_ROWS);
out.SetShape({10, 10});

return block;
}

void CreateVarsOnScope(framework::Scope* scope, platform::CPUPlace* place) {
auto w_var = scope->Var("w");
w_var->GetMutable<framework::SelectedRows>();

auto out_var = scope->Var("out");
out_var->GetMutable<framework::SelectedRows>();

auto ids_var = scope->Var("ids");
ids_var->GetMutable<framework::SelectedRows>();
}

void InitTensorsOnClient(framework::Scope* scope, platform::CPUPlace* place,
int64_t rows_numel) {
CreateVarsOnScope(scope, place);
auto ids_var = scope->Var("ids")->GetMutable<framework::SelectedRows>();
auto rows = ids_var->mutable_rows();
for (int64_t i = 0; i < rows_numel; ++i) rows->push_back(i * 2);
ids_var->mutable_value()->Resize({rows_numel, 1});
ids_var->mutable_value()->mutable_data<float>(*place);
}

void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place,
int64_t rows_numel) {
CreateVarsOnScope(scope, place);
auto w = scope->Var("w")->GetMutable<framework::SelectedRows>();
auto rows = w->mutable_rows();
for (int64_t i = 0; i < rows_numel; ++i) rows->push_back(i);
auto w_value = w->mutable_value();
w_value->Resize({rows_numel, 10});

auto ptr = w_value->mutable_data<float>(*place);

for (int64_t i = 0; i < w_value->numel(); ++i) {
ptr[i] = static_cast<float>(i / 10);
}
}

void StartServer(const std::string& endpoint) {
rpc_service_.reset(new detail::AsyncGRPCServer(endpoint));
framework::ProgramDesc program;
framework::Scope scope;
platform::CPUPlace place;
framework::Executor exe(place);
platform::CPUDeviceContext ctx(place);
auto* block = AppendPrefetchBlcok(&program);
auto prepared = exe.Prepare(program, block->ID());
InitTensorsOnServer(&scope, &place, 10);

rpc_service_->SetProgram(&program);
rpc_service_->SetPrefetchPreparedCtx(prepared.get());
rpc_service_->SetDevCtx(&ctx);
rpc_service_->SetScope(&scope);
rpc_service_->SetExecutor(&exe);

rpc_service_->RunSyncUpdate();
}

TEST(PREFETCH, CPU) {
// start up a server instance backend
// TODO(Yancey1989): Need to start a server with optimize blocks and
// prefetch blocks.
std::thread server_thread(StartServer, "127.0.0.1:8889");
sleep(2);
framework::Scope scope;
platform::CPUPlace place;
platform::CPUDeviceContext ctx(place);
// create var on local scope
std::string in_var_name("in");
int64_t rows_numel = 5;
InitTensorsOnClient(&scope, &place, rows_numel);
std::string in_var_name("ids");
std::string out_var_name("out");
auto* in_var = scope.Var(in_var_name);
auto* in_tensor = in_var->GetMutable<framework::LoDTensor>();
in_tensor->Resize({10, 10});
VLOG(3) << "before mutable_data";
in_tensor->mutable_data<int>(place);

scope.Var(out_var_name);

VLOG(3) << "before fetch";
detail::RPCClient client;
client.AsyncPrefetchVariable("127.0.0.1:8889", ctx, scope, in_var_name,
out_var_name);
client.Wait();

auto var = scope.Var(out_var_name);
auto value = var->GetMutable<framework::SelectedRows>()->value();
auto ptr = value.mutable_data<float>(place);

rpc_service_->ShutDown();
server_thread.join();
rpc_service_.reset(nullptr);

for (int64_t i = 0; i < rows_numel; ++i) {
EXPECT_EQ(ptr[0 + i * value.dims()[1]], static_cast<float>(i * 2));
}
}
4 changes: 3 additions & 1 deletion paddle/fluid/operators/detail/send_recv.proto
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ service SendRecvService {
rpc SendVariable(VariableMessage) returns (VoidMessage) {}
// Argument VariableMessage for GetVariable should only contain varname.
rpc GetVariable(VariableMessage) returns (VariableMessage) {}
// Prefetch variable by Ids
// pre-fetch variable by given variable name and Ids
rpc PrefetchVariable(VariableMessage) returns (VariableMessage) {}
}

Expand Down Expand Up @@ -67,6 +67,8 @@ message VariableMessage {
bytes serialized = 8;
// selected_rows data
bytes rows = 9;
// Look up table block execution output variable name.
string out_varname = 10;
}

message VoidMessage {}
6 changes: 5 additions & 1 deletion paddle/fluid/operators/detail/sendrecvop_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ namespace detail {

void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
const platform::DeviceContext& ctx,
::grpc::ByteBuffer* msg) {
::grpc::ByteBuffer* msg,
const std::string& out_name) {
using VarMsg = sendrecv::VariableMessage;
sendrecv::VariableMessage request;
std::string header;
Expand All @@ -52,6 +53,9 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
e.WriteUint64(VarMsg::kTypeFieldNumber, 1);
}

if (!out_name.empty()) {
e.WriteString(VarMsg::kOutVarnameFieldNumber, out_name);
}
switch (framework::ToVarType(var->Type())) {
case framework::proto::VarType_Type_LOD_TENSOR: {
auto tensor = var->Get<framework::LoDTensor>();
Expand Down
3 changes: 2 additions & 1 deletion paddle/fluid/operators/detail/sendrecvop_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ typedef void (*DestroyCallback)(void*);

void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
const platform::DeviceContext& ctx,
::grpc::ByteBuffer* msg);
::grpc::ByteBuffer* msg,
const std::string& out_varname = std::string());

void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
const platform::DeviceContext& ctx,
Expand Down
14 changes: 14 additions & 0 deletions paddle/fluid/operators/detail/variable_response.cc
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,20 @@ int VariableResponse::Parse(Source* source) {
}
break;
}
case sendrecv::VariableMessage::kOutVarnameFieldNumber: {
uint32_t length;
if ((wt != WIRETYPE_LENGTH_DELIMITED) || !input.ReadVarint32(&length)) {
return tag;
}

std::string temp;
if (!input.ReadString(&temp, length)) {
return tag;
}

meta_.set_out_varname(temp);
break;
}

default: {
// Unknown tag, return unknown error.
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/operators/detail/variable_response.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class VariableResponse {
int Parse(const ::grpc::ByteBuffer& byte_buffer);

inline std::string Varname() { return meta_.varname(); }
inline std::string OutVarname() { return meta_.out_varname(); }

// should call parse first.
framework::Variable* GetVar() { return scope_->FindVar(meta_.varname()); }
Expand Down