Skip to content

Commit

Permalink
Merge pull request #9448 from typhoonzero/fix_dist_slr_height
Browse files Browse the repository at this point in the history
fix dist train selected rows height missing
  • Loading branch information
typhoonzero committed Mar 29, 2018
2 parents 24100e1 + 96192a8 commit d21ab2e
Show file tree
Hide file tree
Showing 8 changed files with 45 additions and 17 deletions.
1 change: 0 additions & 1 deletion paddle/fluid/operators/detail/grpc_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,6 @@ std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) {
}

grpc::ChannelArguments args;
args.SetInt("grpc.testing.fixed_reconnect_backoff_ms", 5000);
args.SetCompressionAlgorithm(GRPC_COMPRESS_NONE);
args.SetMaxSendMessageSize(std::numeric_limits<int>::max());
args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
Expand Down
8 changes: 4 additions & 4 deletions paddle/fluid/operators/detail/send_recv.proto
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,12 @@ message VariableMessage {
// lod details:
int64 lod_level = 5;
repeated LodData lod = 6;
// selected_rows height, aka. original dim0
int64 slr_height = 7;
// tensor data
bytes serialized = 7;
bytes serialized = 8;
// selected_rows data
bytes rows = 8;
bytes rows = 9;
}

message VoidMessage {}

message TestMessage { int64 test_1 = 1; }
3 changes: 2 additions & 1 deletion paddle/fluid/operators/detail/sendrecvop_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
e.WriteUint64(VarMsg::kDimsFieldNumber, dim);
}
e.WriteUint64(VarMsg::kLodLevelFieldNumber, 0);
e.WriteUint64(VarMsg::kSlrHeightFieldNumber, slr->height());
auto* tensor = slr->mutable_value();
if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA
Expand Down Expand Up @@ -154,7 +155,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
ProtoEncodeHelper e2((char*)buf, 128);
// NOTE: rows is of type int64_t
size_t rows_memory_size =
slr->rows().capacity() * framework::SizeOfType(typeid(int64_t));
slr->rows().size() * framework::SizeOfType(typeid(int64_t));
e2.WriteVarlengthBeginning(VarMsg::kRowsFieldNumber, rows_memory_size);
slices[2] = ::grpc::Slice(e2.size());
memcpy(const_cast<uint8_t*>(slices[2].begin()), e2.data(), e2.size());
Expand Down
7 changes: 7 additions & 0 deletions paddle/fluid/operators/detail/sendrecvop_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <sys/time.h>
#include <iostream>
#include <string>
#include <vector>
Expand All @@ -35,6 +36,12 @@ namespace detail {
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"

static int64_t GetTimestamp() {
struct timeval tp;
gettimeofday(&tp, NULL);
return tp.tv_sec * 1000 + tp.tv_usec / 1000;
}

typedef void (*DestroyCallback)(void*);

void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
Expand Down
21 changes: 13 additions & 8 deletions paddle/fluid/operators/detail/test_serde.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,14 @@ void RunSerdeTestSelectedRows(platform::Place place) {
// serialize var to ByteBuffer
framework::Variable var;
auto* slr = var.GetMutable<framework::SelectedRows>();
slr->set_height(1000);
auto* tensor = slr->mutable_value();
auto* rows = slr->mutable_rows();
tensor->Resize(framework::make_ddim({2, 10}));
tensor->Resize(framework::make_ddim({564, 128}));
tensor->mutable_data<float>(place);
int tensor_numel = 2 * 10;
int tensor_numel = 564 * 128;
math::set_constant(ctx, tensor, 32.7);
rows->push_back(3);
rows->push_back(10);
for (int i = 0; i < 564; ++i) rows->push_back(i);

::grpc::ByteBuffer msg;
operators::detail::SerializeToByteBuffer("myvar", &var, ctx, &msg);
Expand All @@ -64,6 +64,7 @@ void RunSerdeTestSelectedRows(platform::Place place) {
sendrecv::VariableMessage varmsg;
EXPECT_TRUE(varmsg.ParseFromString(tmp));

// deserialize bytebuffer
EXPECT_EQ(varmsg.varname(), "myvar");
EXPECT_EQ(varmsg.type(), 1);

Expand All @@ -74,8 +75,10 @@ void RunSerdeTestSelectedRows(platform::Place place) {
for (int i = 0; i < tensor_numel; ++i) {
EXPECT_FLOAT_EQ(tensor_data[i], 32.7);
}
EXPECT_EQ(rows_data[0], 3);
EXPECT_EQ(rows_data[1], 10);
for (int i = 0; i < 564; ++i) {
EXPECT_EQ(rows_data[i], i);
}

// deserialize zero-copy
// framework::Variable var2;
// operators::detail::DeserializeFromByteBuffer(msg, ctx, &var2);
Expand Down Expand Up @@ -104,8 +107,10 @@ void RunSerdeTestSelectedRows(platform::Place place) {
for (int i = 0; i < tensor_numel; ++i) {
EXPECT_FLOAT_EQ(tensor_data2[i], 32.7);
}
EXPECT_EQ(rows_data2[0], 3);
EXPECT_EQ(rows_data2[1], 10);
for (int i = 0; i < rows2->size(); ++i) {
EXPECT_EQ(rows_data2[i], i);
}
EXPECT_EQ(slr2->height(), 1000);
}

void RunTestLodTensor(platform::Place place, int from_type = 0) {
Expand Down
16 changes: 15 additions & 1 deletion paddle/fluid/operators/detail/variable_response.cc
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,13 @@ bool VariableResponse::CopySelectRowsTensorData(
const platform::DeviceContext& ctx, framework::DDim& dims, int length) {
auto var = scope_->FindVar(meta_.varname());
auto* slr = var->GetMutable<framework::SelectedRows>();
slr->set_height(meta_.slr_height());
auto* tensor = slr->mutable_value();
tensor->Resize(dims);
PADDLE_ENFORCE_EQ(
tensor->numel(),
length / framework::SizeOfType(
paddle::operators::detail::ToTypeIndex(meta_.data_type())));
void* tensor_data = tensor->mutable_data(
ctx.GetPlace(),
paddle::operators::detail::ToTypeIndex(meta_.data_type()));
Expand All @@ -165,7 +170,8 @@ bool VariableResponse::CopySelectRowsData(
const platform::DeviceContext& ctx, int length) {
auto var = scope_->FindVar(meta_.varname());
auto* slr = var->GetMutable<framework::SelectedRows>();
slr->mutable_rows()->resize(length / 8); // int64
slr->mutable_rows()->resize(length /
framework::SizeOfType(typeid(int64_t))); // int64
int64_t* rows_data = slr->mutable_rows()->data();

// copy rows CPU data, GPU data will be copied lazily.
Expand Down Expand Up @@ -348,6 +354,14 @@ int VariableResponse::Parse(Source* source) {
}
break;
}
case sendrecv::VariableMessage::kSlrHeightFieldNumber: {
uint64_t v = 0;
if ((wt != WIRETYPE_VARINT) || !input.ReadVarint64(&v)) {
return tag;
}
meta_.set_slr_height(static_cast<int64_t>(v));
break;
}
case sendrecv::VariableMessage::kSerializedFieldNumber: {
PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS ||
meta_.type() == sendrecv::LOD_TENSOR) &&
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/operators/listen_and_serv_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ class ListenAndServOp : public framework::OperatorBase {
// and this will still work.

std::vector<std::future<void>> fs;
double ts = detail::GetTimestamp();
// block0 contains only listen_and_serv op, start run from block1.
for (int blkid = 1; blkid < num_blocks - 1; ++blkid) {
fs.push_back(
Expand All @@ -162,6 +163,7 @@ class ListenAndServOp : public framework::OperatorBase {
LOG(ERROR) << "run sub program error " << e.what();
}
}
VLOG(2) << "run all blocks spent (ms) " << detail::GetTimestamp() - ts;

// Reset the received sparse variables, the sum operator would not
// sum the input sparse variables which rows is empty at the next
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/operators/send_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class SendOp : public framework::OperatorBase {

for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) {
VLOG(2) << "sending " << ins[i] << " to " << epmap[i];
VLOG(3) << "sending " << ins[i] << " to " << epmap[i];
rpc_client->AsyncSendVariable(epmap[i], ctx, scope, ins[i]);
} else {
VLOG(3) << "don't send no-initialied variable: " << ins[i];
Expand All @@ -81,7 +81,7 @@ class SendOp : public framework::OperatorBase {
PADDLE_ENFORCE(rpc_client->Wait());

for (auto& ep : endpoints) {
VLOG(2) << "batch barrier, ep: " << ep;
VLOG(3) << "batch barrier, ep: " << ep;
rpc_client->AsyncSendBatchBarrier(ep);
}
PADDLE_ENFORCE(rpc_client->Wait());
Expand Down

0 comments on commit d21ab2e

Please sign in to comment.