run prefetch prog on server #9593

Merged
merged 7 commits into from
Apr 8, 2018
2 changes: 1 addition & 1 deletion paddle/fluid/operators/detail/CMakeLists.txt
@@ -5,5 +5,5 @@ if(WITH_DISTRIBUTE)
set_source_files_properties(serde_test.cc grpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(serde_test SRCS serde_test.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr
cares zlib protobuf sendrecvop_grpc)
cc_test(grpc_server_test SRCS grpc_server_test.cc DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf)
cc_test(grpc_server_test SRCS grpc_server_test.cc DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_table_op)
endif()
2 changes: 1 addition & 1 deletion paddle/fluid/operators/detail/grpc_client.cc
@@ -136,7 +136,7 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep,
auto* var = p_scope->FindVar(in_var_name_val);

::grpc::ByteBuffer req;
SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req);
SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req, out_var_name_val);

// var handle
VarHandle var_h;
7 changes: 5 additions & 2 deletions paddle/fluid/operators/detail/grpc_server.cc
@@ -157,9 +157,12 @@ class RequestPrefetch final : public RequestBase {
virtual void Process() {
// prefetch process...
::grpc::ByteBuffer reply;
// TODO(Yancey1989): execute the Block which containers prefetch ops

VLOG(3) << "RequestPrefetch Process in";
executor_->Run(*program_, scope_, blkid_, false, false);
Contributor

Process runs in a separate thread, so executor_ may be accessed from different threads at the same time. I don't know whether this is safe.

Member

The problem is that prefetch and optimize may happen at the same time, and both will access the lookup_table parameter.

I think the final solution may be to run table optimization in a separate thread as well, with the prefetch thread and the update thread acquiring the same lock.

Currently, we run the update operators within the optimize block, so we should find a way to avoid the conflict.

Contributor Author

> The problem is that the prefetch and optimize may happen at the same time

Maybe not. In the current process, the prefetch request happens before the gradients are sent, and there is a SEND BARRIER to make sure that the optimize process happens after the prefetch request.

Contributor Author

@Yancey1989 Yancey1989 Apr 3, 2018

But I think the current implementation is not thread safe, because we use one scope to create the output variable. If there are two or more prefetch requests, the output variable would be replaced and the serialize function would fail.
One way to solve this may be to use a different scope to create the output var.

Member

We can create a sub_scope here to store the output variable.

Contributor Author

@Yancey1989 Yancey1989 Apr 3, 2018

But NewScope may not be thread safe either...
Another way may be to create multiple output vars with different suffixes, such as out_trainer0 and out_trainer1, in the Distributed transpiler.

Member

After discussing with @Yancey1989, we decided to use NewScope to run each Process.

Contributor Author

Done.

Member

The request needs to be deserialized into the current scope before running the prefetch block.

Contributor Author

Done.
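
The sub-scope approach agreed on in this thread can be illustrated with a toy model. This is only a sketch under stated assumptions, not Paddle's `framework::Scope`: `ToyScope`, `InterleavedPrefetches`, and the integer-valued variables are hypothetical stand-ins, and the mutex inside `NewScope` is an assumption added to address the worry that scope creation itself may not be thread safe. Each request writes `out` into its own child scope while still reading the shared `w` from the parent, so a later request cannot replace an earlier request's output before it is serialized.

```cpp
#include <cassert>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for a hierarchical variable scope.
class ToyScope {
 public:
  explicit ToyScope(const ToyScope* parent = nullptr) : parent_(parent) {}

  // Creates a child scope owned by this scope. The lock is an assumption:
  // it serializes concurrent child creation.
  ToyScope* NewScope() {
    std::lock_guard<std::mutex> guard(mu_);
    kids_.emplace_back(new ToyScope(this));
    return kids_.back().get();
  }

  // Creates or overwrites a variable in this scope only.
  void Set(const std::string& name, int value) { vars_[name] = value; }

  // Looks up locally first, then in ancestors; nullptr if not found.
  const int* Find(const std::string& name) const {
    auto it = vars_.find(name);
    if (it != vars_.end()) return &it->second;
    return parent_ ? parent_->Find(name) : nullptr;
  }

 private:
  const ToyScope* parent_;
  std::mutex mu_;
  std::unordered_map<std::string, int> vars_;
  std::vector<std::unique_ptr<ToyScope>> kids_;
};

// Mimics interleaved requests: every request first writes its own "out",
// and only afterwards are all outputs read back for the replies.
std::vector<int> InterleavedPrefetches(const std::vector<int>& ids) {
  ToyScope root;
  root.Set("w", 100);  // shared parameter, visible to every child scope
  std::vector<ToyScope*> locals;
  for (int id : ids) {
    ToyScope* local = root.NewScope();
    const int* w = local->Find("w");
    assert(w != nullptr);
    local->Set("out", *w + id);  // private to this request
    locals.push_back(local);
  }
  std::vector<int> replies;
  for (ToyScope* local : locals) replies.push_back(*local->Find("out"));
  return replies;
}
```

With a single shared scope, every reply in this interleaving would carry the last request's value; with one child scope per request, each reply keeps its own.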


std::string var_name = request_.out_varname();
auto* var = scope_->FindVar(var_name);
SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply);

responder_.Finish(reply, ::grpc::Status::OK, this);
status_ = FINISH;
79 changes: 71 additions & 8 deletions paddle/fluid/operators/detail/grpc_server_test.cc
@@ -20,43 +20,106 @@ limitations under the License. */
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/detail/grpc_server.h"

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"

namespace framework = paddle::framework;
namespace platform = paddle::platform;
namespace detail = paddle::operators::detail;

USE_OP(lookup_table);

std::unique_ptr<detail::AsyncGRPCServer> rpc_service_;

framework::BlockDesc* AppendPrefetchBlcok(framework::ProgramDesc& program) {
const auto &root_block = program.Block(0);
auto *block= program.AppendBlock(root_block);

framework::VariableNameMap input({{"W", {"w"}}, {"Ids", {"ids"}}});
framework::VariableNameMap output({{"Output", {"out"}}});
auto op = block->AppendOp();
op->SetType("lookup_table");
op->SetInput("W", {"w"});
op->SetInput("Ids", {"ids"});
op->SetOutput("Out", {"out"});
return block;
}

void InitTensorsInScope(framework::Scope &scope, platform::CPUPlace &place) {
Member

InitTensorsInScope should be split into InitTensorsInClientScope and InitTensorsInServerScope.

Contributor Author

Done.

auto w_var = scope.Var("w");
auto w = w_var->GetMutable<framework::LoDTensor>();
Member

W should be a SelectedRows, not a LoDTensor.

Contributor Author

Done.

w->Resize({10, 10});
float *ptr = w->mutable_data<float>(place);
for (int64_t i = 0; i < w->numel(); ++i) {
ptr[i] = static_cast<float>(i/10);
}

auto out_var = scope.Var("out");
auto out = out_var->GetMutable<framework::LoDTensor>();
out->Resize({5, 10});
out->mutable_data<float>(place);

auto ids_var = scope.Var("ids");
auto ids = ids_var->GetMutable<framework::LoDTensor>();
ids->Resize({5, 1});
auto ids_ptr = ids->mutable_data<int64_t>(place);
for (int64_t i = 0; i < ids->numel(); ++i) {
ids_ptr[i] = i * 2;
}
}


void StartServer(const std::string& endpoint) {
rpc_service_.reset(new detail::AsyncGRPCServer(endpoint));
framework::ProgramDesc program;
framework::Scope scope;
platform::CPUPlace place;
framework::Executor exe(place);
platform::CPUDeviceContext ctx(place);
auto* block = AppendPrefetchBlcok(program);
InitTensorsInScope(scope, place);

rpc_service_->SetProgram(&program);
rpc_service_->SetPrefetchBlkdId(block->ID());
rpc_service_->SetDevCtx(&ctx);
rpc_service_->SetScope(&scope);
rpc_service_->SetExecutor(&exe);

rpc_service_->RunSyncUpdate();
}


TEST(PREFETCH, CPU) {
// start up a server instance backend
// TODO(Yancey1989): Need to start a server with optimize blocks and
// prefetch blocks.
std::thread server_thread(StartServer, "127.0.0.1:8889");
sleep(3);
framework::Scope scope;
platform::CPUPlace place;
platform::CPUDeviceContext ctx(place);
// create var on local scope
std::string in_var_name("in");
InitTensorsInScope(scope, place);
std::string in_var_name("ids");
std::string out_var_name("out");
auto* in_var = scope.Var(in_var_name);
auto* in_tensor = in_var->GetMutable<framework::LoDTensor>();
in_tensor->Resize({10, 10});
VLOG(3) << "before mutable_data";
in_tensor->mutable_data<int>(place);

scope.Var(out_var_name);

VLOG(3) << "before fetch";
detail::RPCClient client;
client.AsyncPrefetchVariable("127.0.0.1:8889", ctx, scope, in_var_name,
out_var_name);
client.Wait();

auto out_var = scope.Var(out_var_name);
auto out = out_var->Get<framework::LoDTensor>();
auto out_ptr = out.data<float>();
rpc_service_->ShutDown();
server_thread.join();
rpc_service_.reset(nullptr);

EXPECT_EQ(out.dims().size(), 2);
EXPECT_EQ(out_ptr[0], static_cast<float>(0));
EXPECT_EQ(out_ptr[0 + 1 * out.dims()[1]], static_cast<float>(2));
EXPECT_EQ(out_ptr[0 + 2 * out.dims()[1]], static_cast<float>(4));
EXPECT_EQ(out_ptr[0 + 3 * out.dims()[1]], static_cast<float>(6));
EXPECT_EQ(out_ptr[0 + 4 * out.dims()[1]], static_cast<float>(8));
}
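
The expected values in the assertions above follow from the test inputs: `w` is a 10x10 table whose flat element `i` holds `i / 10`, so every element of row `r` equals `r`, and `ids` is 0, 2, 4, 6, 8. The sketch below reproduces that arithmetic in plain C++ without any Paddle types. `LookupRows` and `ExpectedPrefetchOutput` are hypothetical helper names, and the row-copy semantics are an assumption about what the `lookup_table` op does in this test.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Row-major table lookup: copies row ids[k] of w into row k of the result.
std::vector<float> LookupRows(const std::vector<float>& w, int64_t cols,
                              const std::vector<int64_t>& ids) {
  const int64_t n = static_cast<int64_t>(w.size());
  assert(cols > 0 && n % cols == 0);
  const int64_t rows = n / cols;
  std::vector<float> out;
  out.reserve(ids.size() * static_cast<std::size_t>(cols));
  for (int64_t id : ids) {
    assert(id >= 0 && id < rows);
    out.insert(out.end(), w.begin() + id * cols, w.begin() + (id + 1) * cols);
  }
  return out;
}

// Builds the same inputs as the test: a 10x10 table and ids = 0, 2, 4, 6, 8.
std::vector<float> ExpectedPrefetchOutput() {
  std::vector<float> w(100);
  for (int64_t i = 0; i < 100; ++i) w[i] = static_cast<float>(i / 10);
  std::vector<int64_t> ids;
  for (int64_t i = 0; i < 5; ++i) ids.push_back(i * 2);
  return LookupRows(w, 10, ids);
}
```

The result is a 5x10 block whose row `k` is filled with `2 * k`, which is why the test checks the first element of each row against 0, 2, 4, 6, and 8.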
2 changes: 2 additions & 0 deletions paddle/fluid/operators/detail/send_recv.proto
@@ -67,6 +67,8 @@ message VariableMessage {
bytes serialized = 8;
// selected_rows data
bytes rows = 9;
// prefetch var name
Contributor

Look up table block execution output variable name.

Contributor Author

Done.

Contributor

Seems not updated?

Contributor Author

Sorry, updated per the comments.

string out_varname = 10;
}

message VoidMessage {}
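
For reference, a `string` field numbered 10 is length-delimited in the protobuf wire format: a tag of `(10 << 3) | 2 = 0x52`, a varint length, then the raw bytes. The sketch below hand-encodes that layout with the standard library only. `AppendVarint`, `AppendStringField`, and `EncodeOutVarname` are hypothetical names for illustration, not the encoder used in `sendrecvop_utils.cc`; the empty-string guard mirrors the one this PR adds around `out_name`.

```cpp
#include <cassert>
#include <cstdint>
#include <string>

// Appends v as a protobuf base-128 varint (low 7 bits first).
void AppendVarint(uint64_t v, std::string* out) {
  while (v >= 0x80) {
    out->push_back(static_cast<char>((v & 0x7F) | 0x80));
    v >>= 7;
  }
  out->push_back(static_cast<char>(v));
}

// Appends a length-delimited (wire type 2) field: tag, length, payload.
void AppendStringField(uint32_t field_number, const std::string& value,
                       std::string* out) {
  assert(field_number > 0);
  AppendVarint((static_cast<uint64_t>(field_number) << 3) | 2, out);
  AppendVarint(value.size(), out);
  out->append(value);
}

// Emits out_varname (field 10) only when a name was supplied.
std::string EncodeOutVarname(const std::string& out_name) {
  std::string buf;
  if (!out_name.empty()) AppendStringField(10, out_name, &buf);
  return buf;
}
```

Because the field is skipped when the name is empty, ordinary send and get messages stay byte-for-byte unchanged; only prefetch requests carry the extra field.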
6 changes: 5 additions & 1 deletion paddle/fluid/operators/detail/sendrecvop_utils.cc
@@ -28,7 +28,8 @@ namespace detail {

void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
const platform::DeviceContext& ctx,
::grpc::ByteBuffer* msg) {
::grpc::ByteBuffer* msg,
const std::string& out_name) {
using VarMsg = sendrecv::VariableMessage;
sendrecv::VariableMessage request;
std::string header;
@@ -50,6 +51,9 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
e.WriteUint64(VarMsg::kTypeFieldNumber, 1);
}

if (!out_name.empty()) {
e.WriteString(VarMsg::kOutVarnameFieldNumber, out_name);
}
switch (framework::ToVarType(var->Type())) {
case framework::proto::VarType_Type_LOD_TENSOR: {
auto tensor = var->Get<framework::LoDTensor>();
3 changes: 2 additions & 1 deletion paddle/fluid/operators/detail/sendrecvop_utils.h
@@ -46,7 +46,8 @@ typedef void (*DestroyCallback)(void*);

void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
const platform::DeviceContext& ctx,
::grpc::ByteBuffer* msg);
::grpc::ByteBuffer* msg,
const std::string& out_varname = std::string());

void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
const platform::DeviceContext& ctx,