fix distribute transpiler GRPC error code 4, RPC Deadline #18984

Merged
merged 5 commits on Aug 26, 2019
paddle/fluid/operators/distributed_ops/fetch_barrier_op.cc: 10 changes (6 additions, 4 deletions)
@@ -40,13 +40,15 @@ class FetchBarrierOp : public framework::OperatorBase {
distributed::RPCClient::GetInstance<RPCCLIENT_T>(
Attr<int>("trainer_id"));

PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");

std::vector<distributed::VarHandlePtr> rets;
for (auto& ep : eps) {
VLOG(3) << "fetch barrier, ep: " << ep;
rpc_client->AsyncSendFetchBarrier(ep);
rets.push_back(rpc_client->AsyncSendFetchBarrier(ep));
}

for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, "internal error in RPCClient");
}
PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
}
};

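The change in every C++ operator below follows one pattern: collect the `distributed::VarHandlePtr` returned by each async RPC call and wait on each handle individually, rather than relying on a blanket `rpc_client->Wait()`. The new checks treat a zero return from `Wait()` as failure, which is why they read `PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, ...)`. A minimal standalone sketch of that collect-then-wait shape (the types and helper below are illustrative stand-ins, not the real Paddle RPC client API):

```cpp
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

// Stand-in for distributed::VarHandle; assumes Wait() returns nonzero on
// success, matching the PADDLE_ENFORCE_NE(..., 0U, ...) checks in this PR.
struct DummyHandle {
  int Wait() { return 1; }
};
using DummyHandlePtr = std::shared_ptr<DummyHandle>;

// Hypothetical async call standing in for AsyncSendFetchBarrier / AsyncSendBatchBarrier.
DummyHandlePtr AsyncCall(const std::string& ep) {
  (void)ep;  // a real client would issue the RPC to `ep` here
  return std::make_shared<DummyHandle>();
}

void BarrierAll(const std::vector<std::string>& eps) {
  std::vector<DummyHandlePtr> rets;
  rets.reserve(eps.size());
  for (const auto& ep : eps) {
    rets.push_back(AsyncCall(ep));  // fire every RPC before waiting on any of them
  }
  for (const auto& ret : rets) {
    // Mirrors PADDLE_ENFORCE_NE(ret->Wait(), 0U, "internal error in RPCClient").
    if (ret->Wait() == 0) {
      throw std::runtime_error("internal error in RPCClient");
    }
  }
}

int main() {
  BarrierAll({"127.0.0.1:6164", "127.0.0.1:6165"});
  return 0;
}
```

Waiting per handle lets a failure be attributed to the specific RPC that produced it, instead of being folded into a single global `Wait()` result.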
paddle/fluid/operators/distributed_ops/recv_op.cc: 22 changes (7 additions, 15 deletions)
@@ -44,7 +44,7 @@ class RecvOp : public framework::OperatorBase {
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
std::vector<std::string> varnames =
Attr<std::vector<std::string>>("varnames");
int sync_mode = Attr<int>("sync_mode");

auto outs = Outputs("Out");
bool with_barrier = Attr<bool>("with_barrier");

@@ -64,32 +64,28 @@ class RecvOp : public framework::OperatorBase {
trainer_id);
recv_functor(rpc_ctx, scope);
} else {
std::vector<distributed::VarHandlePtr> rets;
if (with_barrier) {
std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < outs.size(); i++) {
std::string varname = varnames.size() == 0 ? outs[i] : varnames[i];
VLOG(4) << "recv " << outs[i] << " from " << epmap[i] << " with "
<< varname << " and with AsyncGetVar";
rets.push_back(
rpc_client->AsyncGetVar(epmap[i], ctx, scope, varname, outs[i]));
}
if (sync_mode) {
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
}
}
} else {
std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < outs.size(); i++) {
std::string varname = varnames.size() == 0 ? outs[i] : varnames[i];
VLOG(4) << "recv " << outs[i] << " from " << epmap[i] << " with "
<< varname << " and with AsyncGetVarNoBarrier";
rets.push_back(rpc_client->AsyncGetVarNoBarrier(epmap[i], ctx, scope,
varname, outs[i]));
}
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
}
}
for (size_t i = 0; i < rets.size(); i++) {
VLOG(7) << "before sync_recv " << outs[i] << "from " << epmap[i];
PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, "internal error in RPCClient");
VLOG(7) << "after sync_recv " << outs[i] << "from " << epmap[i];
}
}
}
@@ -112,10 +108,6 @@ This operator can get variables from server side.
"variables for mapping")
.SetDefault({});
AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
AddAttr<int>("sync_mode",
"(int, default 0)"
"sync recv or async recv.")
.SetDefault(0);
AddAttr<bool>("with_barrier",
"(bool, default True) if with_barrier=False, will use "
"AsyncGetVarNoBarrier get variable from pserver immediately")
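In recv_op.cc the extra change is structural: `rets` is hoisted out of the `with_barrier` branches so both `AsyncGetVar` and `AsyncGetVarNoBarrier` feed the same handle list, and the old `if (sync_mode)` guard and the `sync_mode` attribute are gone, so the op now always blocks until every get completes. A simplified sketch of that control flow, again with placeholder types rather than the real `RPCClient` interface:

```cpp
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

struct DummyHandle {
  int Wait() { return 1; }  // assumption: nonzero means the RPC completed successfully
};
using DummyHandlePtr = std::shared_ptr<DummyHandle>;

// Hypothetical stand-ins for rpc_client->AsyncGetVar / AsyncGetVarNoBarrier.
DummyHandlePtr AsyncGet(const std::string& ep, const std::string& var) {
  (void)ep;
  (void)var;
  return std::make_shared<DummyHandle>();
}
DummyHandlePtr AsyncGetNoBarrier(const std::string& ep, const std::string& var) {
  (void)ep;
  (void)var;
  return std::make_shared<DummyHandle>();
}

void RecvAll(const std::vector<std::string>& epmap,
             const std::vector<std::string>& outs, bool with_barrier) {
  // One handle list shared by both branches: the hoisted `rets` in the diff above.
  std::vector<DummyHandlePtr> rets;
  for (size_t i = 0; i < outs.size(); ++i) {
    rets.push_back(with_barrier ? AsyncGet(epmap[i], outs[i])
                                : AsyncGetNoBarrier(epmap[i], outs[i]));
  }
  // No sync_mode switch any more: every receive is waited on unconditionally.
  for (size_t i = 0; i < rets.size(); ++i) {
    if (rets[i]->Wait() == 0) {
      throw std::runtime_error("internal error in RPCClient");
    }
  }
}

int main() {
  RecvAll({"127.0.0.1:6164"}, {"param_0"}, /*with_barrier=*/true);
  return 0;
}
```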
paddle/fluid/operators/distributed_ops/send_barrier_op.cc: 11 changes (7 additions, 4 deletions)
@@ -44,13 +44,16 @@ class SendBarrierOp : public framework::OperatorBase {

VLOG(3) << "SendBarrierOp sync";

// need to wait before sending send_barrier message
PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
std::vector<distributed::VarHandlePtr> rets;

for (auto& ep : eps) {
VLOG(3) << "send barrier, ep: " << ep;
rpc_client->AsyncSendBatchBarrier(ep);
rets.push_back(rpc_client->AsyncSendBatchBarrier(ep));
}

for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, "internal error in RPCClient");
}
PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
}
};

paddle/fluid/operators/distributed_ops/send_op.cc: 15 changes (4 additions, 11 deletions)
@@ -41,7 +41,6 @@ class SendOp : public framework::OperatorBase {
auto ins = Inputs("X");

auto epmap = Attr<std::vector<std::string>>("epmap");
int sync_send = Attr<int>("sync_mode");
auto trainer_id = Attr<int>("trainer_id");

auto send_varnames = Attr<std::vector<std::string>>("send_varnames");
@@ -75,12 +74,10 @@ class SendOp : public framework::OperatorBase {
VLOG(3) << "don't send no-initialied variable: " << ins[i];
}
}
if (sync_send) {
for (size_t i = 0; i < rets.size(); i++) {
VLOG(7) << "before sync_send " << ins[i] << "from " << epmap[i];
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
VLOG(7) << "after sync_send " << ins[i] << "from " << epmap[i];
}
for (size_t i = 0; i < rets.size(); i++) {
VLOG(7) << "before sync_send " << ins[i] << "from " << epmap[i];
PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, "internal error in RPCClient");
VLOG(7) << "after sync_send " << ins[i] << "from " << epmap[i];
}
}
}
@@ -98,10 +95,6 @@ Send operator

This operator will send variables to listen_and_serve op at the parameter server.
)DOC");
AddAttr<int>("sync_mode",
"(int, default 0)"
"sync send or async send.")
.SetDefault(0);
AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
AddAttr<std::vector<std::string>>("epmap",
"(string vector, default 127.0.0.1:6164)"
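send_op.cc moves in the same direction: the `sync_send` attribute and its guard are removed, so the wait loop always runs, while a handle is only pushed for variables that are actually initialized (the `"don't send no-initialied variable"` branch above). A rough sketch of that conditional-send, unconditional-wait shape, with a hypothetical `AsyncSend` standing in for the real send call:

```cpp
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

struct DummyHandle {
  int Wait() { return 1; }  // assumption: nonzero == success
};
using DummyHandlePtr = std::shared_ptr<DummyHandle>;

// Hypothetical stand-in for the client's async send call.
DummyHandlePtr AsyncSend(const std::string& ep, const std::string& var) {
  (void)ep;
  (void)var;
  return std::make_shared<DummyHandle>();
}

void SendAll(const std::vector<std::string>& epmap,
             const std::vector<std::string>& ins,
             const std::vector<bool>& initialized) {
  std::vector<DummyHandlePtr> rets;
  for (size_t i = 0; i < ins.size(); ++i) {
    if (initialized[i]) {
      rets.push_back(AsyncSend(epmap[i], ins[i]));  // send only initialized variables
    }
  }
  // With sync_mode gone the op always blocks here, one handle at a time.
  for (auto& ret : rets) {
    if (ret->Wait() == 0) {
      throw std::runtime_error("internal error in RPCClient");
    }
  }
}

int main() {
  SendAll({"127.0.0.1:6164"}, {"grad_0"}, {true});
  return 0;
}
```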
python/paddle/fluid/transpiler/distribute_transpiler.py: 8 changes (2 additions, 6 deletions)
@@ -574,8 +574,7 @@ def transpile(self,
OP_ROLE_VAR_ATTR_NAME: [
self.grad_name_to_param_name[grad_varname],
splited_grad_varname
],
"sync_mode": not self.sync_mode,
]
})
for _, var in enumerate(splited_vars):
send_vars.append(var)
@@ -595,7 +594,6 @@ def transpile(self,
outputs={"Out": send_barrier_out},
attrs={
"endpoints": pserver_endpoints,
"sync_mode": self.sync_mode,
"trainer_id": self.trainer_id,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
})
@@ -669,8 +667,7 @@ def transpile(self,
"trainer_id": self.trainer_id,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
OP_ROLE_VAR_ATTR_NAME:
[param_varname, recv_op_role_var_name],
"sync_mode": not self.sync_mode
[param_varname, recv_op_role_var_name]
})

if self.sync_mode:
@@ -1548,7 +1545,6 @@ def _split_table_grad_and_add_send_vars(self, program, pserver_endpoints):
if self.sync_mode else []
},
attrs={
"sync_mode": not self.sync_mode,
"epmap": pserver_endpoints,
"trainer_id": self.trainer_id,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,