Skip to content

Commit

Permalink
rename rpc_token to TransportToken (#5735)
Browse files Browse the repository at this point in the history
* rename rpc_token to TransportToken

* minor fix

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
  • Loading branch information
lixinqi and oneflow-ci-bot committed Aug 5, 2021
1 parent 9308f54 commit 37df98a
Show file tree
Hide file tree
Showing 26 changed files with 602 additions and 580 deletions.
6 changes: 3 additions & 3 deletions oneflow/api/python/framework/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ void ApiRegisterTensorHook(const std::shared_ptr<Tensor>& self, const AutogradMe

Maybe<void> CheckConsistentTensorMeta(const one::Tensor& tensor, int64_t seconds) {
const auto& ctx = JUST(LaunchTensorMetaConsistencyCheck(tensor));
JUST(RpcUtil::WaitUntilDoneOrTimeout(*ctx, seconds));
JUST(TransportUtil::WaitUntilDoneOrTimeout(*ctx, seconds));
JUST(ctx->Check());
return Maybe<void>::Ok();
}
Expand Down Expand Up @@ -412,9 +412,9 @@ ONEFLOW_API_PYBIND11_MODULE("", m) {
.def_property_readonly("_tensor_buffer_shapes_and_dtypes", &GetTensorBufferShapesAndDTypes)
.def_property_readonly("device", &TensorGetDevice)
.def_property_readonly("data", &Tensor::data)
.def("rpc_token",
.def("consistent_id",
[](const one::Tensor& tensor) -> int64_t {
return static_cast<uint64_t>(tensor.rpc_token().GetOrThrow());
return static_cast<uint64_t>(tensor.transport_token().GetOrThrow());
})
.def("check_meta_consistency",
[](const one::Tensor& tensor) {
Expand Down
18 changes: 9 additions & 9 deletions oneflow/api/python/rpc/consistent_rpc_token_scope.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ namespace oneflow {

namespace {

Maybe<void> InitConsistentRpcTokenScope(const std::string& thread_tag,
int64_t thread_consistent_uid,
Symbol<RankGroup> rank_group) {
Maybe<void> InitConsistentTransportTokenScope(const std::string& thread_tag,
int64_t thread_consistent_uid,
Symbol<RankGroup> rank_group) {
JUST(SetThisThreadConsistentUniqueId(thread_consistent_uid, thread_tag));
static thread_local const auto& init_rank_group_scope =
JUST(RankGroupScope::MakeInitialRankGroupScope(rank_group));
Expand All @@ -40,21 +40,21 @@ Maybe<void> InitConsistentRpcTokenScope(const std::string& thread_tag,
return Maybe<void>::Ok();
}

Maybe<void> InitConsistentRpcTokenScope(const std::string& thread_tag,
int64_t thread_consistent_uid) {
Maybe<void> InitConsistentTransportTokenScope(const std::string& thread_tag,
int64_t thread_consistent_uid) {
const auto& rank_group = JUST(RankGroup::DefaultRankGroup());
JUST(InitConsistentRpcTokenScope(thread_tag, thread_consistent_uid, rank_group));
JUST(InitConsistentTransportTokenScope(thread_tag, thread_consistent_uid, rank_group));
return Maybe<void>::Ok();
}

void ApiInitDefaultConsistentRpcTokenScope() {
return InitConsistentRpcTokenScope("main", 0).GetOrThrow();
void ApiInitDefaultConsistentTransportTokenScope() {
return InitConsistentTransportTokenScope("main", 0).GetOrThrow();
}

} // namespace

ONEFLOW_API_PYBIND11_MODULE("", m) {
m.def("InitDefaultConsistentRpcTokenScope", &ApiInitDefaultConsistentRpcTokenScope);
m.def("InitDefaultConsistentTransportTokenScope", &ApiInitDefaultConsistentTransportTokenScope);
}

} // namespace oneflow
4 changes: 2 additions & 2 deletions oneflow/api/python/rpc/rank_group.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ namespace {

Maybe<void> CheckCurrentRankGroupConsistency(int64_t seconds) {
const auto& rank_group = JUST(RankGroupScope::CurrentRankGroup());
const auto& ctx = JUST(CheckRpcToken(rank_group));
JUST(RpcUtil::WaitUntilDoneOrTimeout(*ctx, seconds));
const auto& ctx = JUST(CheckTransportToken(rank_group));
JUST(TransportUtil::WaitUntilDoneOrTimeout(*ctx, seconds));
return Maybe<void>::Ok();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ Maybe<void> Interpret(const UserOpExpr& user_op_expr, const TensorTuple& inputs,
for (int i = 0; i < outputs->size(); ++i) {
const auto& tensor_impl = JUST(EagerConsistentTensorImpl::New(output_tensor_metas.at(i), device,
parallel_id, false, false));
const auto& rpc_token = JUST(RpcToken::NewMetaRpcToken());
JUST(tensor_impl->set_rpc_token(rpc_token));
const auto& transport_token = JUST(TransportToken::NewMetaTransportToken());
JUST(tensor_impl->set_transport_token(transport_token));
outputs->at(i).reset(new ConsistentTensor(tensor_impl));
}
// Do nothing if the `parallel_desc` doesn't cover current ProcessCtx.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -259,8 +259,8 @@ Maybe<void> EagerMirroredInterpreter::ApplyImpl(const CastToConsistentOpExpr& op
const auto& consistent_tensor_impl = JUST(EagerConsistentTensorImpl::New(
SymbolOf(tensor_meta), device, parallel_id, input_mirrored_tensor->requires_grad(),
!input_mirrored_tensor->requires_grad()));
const auto& rpc_token = JUST(RpcToken::NewMetaRpcToken());
JUST(consistent_tensor_impl->set_rpc_token(rpc_token));
const auto& transport_token = JUST(TransportToken::NewMetaTransportToken());
JUST(consistent_tensor_impl->set_transport_token(transport_token));
consistent_tensor = std::make_shared<ConsistentTensor>(consistent_tensor_impl);
const auto& ctx = JUST(LaunchTensorMetaConsistencyCheck(*consistent_tensor));
if (parallel_id.has_value()) {
Expand All @@ -269,7 +269,7 @@ Maybe<void> EagerMirroredInterpreter::ApplyImpl(const CastToConsistentOpExpr& op
consistent_tensor_impl->reset_cur_rank_phy_tensor(
std::dynamic_pointer_cast<MirroredTensor>(synced_tensor));
}
JUST(RpcUtil::WaitUntilDoneOrTimeout(*ctx, RpcUtil::TimeoutSeconds()));
JUST(TransportUtil::WaitUntilDoneOrTimeout(*ctx, TransportUtil::TimeoutSeconds()));
JUST(ctx->Check());
}
outputs->at(0) = consistent_tensor;
Expand Down
15 changes: 8 additions & 7 deletions oneflow/core/framework/rank_group_rpc_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ limitations under the License.
#include <memory>
#include <chrono>
#include "oneflow/core/framework/rank_group_rpc_util.h"
#include "oneflow/core/framework/rpc_util.h"
#include "oneflow/core/framework/transport_util.h"
#include "oneflow/core/common/data_type.h"
#include "oneflow/core/common/container_util.h"
#include "oneflow/core/job/rank_group.h"
Expand All @@ -27,9 +27,9 @@ limitations under the License.

namespace oneflow {

Maybe<NaiveAsyncRpcCtx> CheckRpcToken(Symbol<RankGroup> rank_group) {
const auto& rpc_token =
JUST(RpcToken::AcquireCtrlRpcToken(kRankGroupRpcCmdCheckRankGroupConsistency));
Maybe<NaiveAsyncTransportCtx> CheckTransportToken(Symbol<RankGroup> rank_group) {
const auto& transport_token =
JUST(TransportToken::AcquireCtrlTransportToken(kRankGroupCtrlCmdCheckRankGroupConsistency));
const auto& PrepareBuffer = [](void** buffer, std::size_t* size,
std::function<void()>* Callback) -> Maybe<void> {
const auto& placeholder = std::make_shared<uint32_t>();
Expand All @@ -38,9 +38,10 @@ Maybe<NaiveAsyncRpcCtx> CheckRpcToken(Symbol<RankGroup> rank_group) {
*Callback = [placeholder]() {};
return Maybe<void>::Ok();
};
const auto& ctx = std::make_shared<NaiveAsyncRpcCtx>(rpc_token, PrepareBuffer, PrepareBuffer);
JUST(RpcUtil::SendToNextRankInRing(rank_group, rpc_token, ctx.get()));
JUST(RpcUtil::ReceiveFromPrevRankInRing(rank_group, rpc_token, ctx.get()));
const auto& ctx =
std::make_shared<NaiveAsyncTransportCtx>(transport_token, PrepareBuffer, PrepareBuffer);
JUST(TransportUtil::SendToNextRankInRing(rank_group, transport_token, ctx.get()));
JUST(TransportUtil::ReceiveFromPrevRankInRing(rank_group, transport_token, ctx.get()));
return ctx;
}

Expand Down
6 changes: 3 additions & 3 deletions oneflow/core/framework/rank_group_rpc_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@ limitations under the License.
#ifndef ONEFLOW_CORE_FRAMEWORK_PLACEMENT_RPC_UTIL_H_
#define ONEFLOW_CORE_FRAMEWORK_PLACEMENT_RPC_UTIL_H_

#include "oneflow/core/framework/rpc_token.h"
#include "oneflow/core/framework/rpc_util.h"
#include "oneflow/core/framework/transport_token.h"
#include "oneflow/core/framework/transport_util.h"
#include "oneflow/core/common/symbol.h"
#include "oneflow/core/job/rank_group.h"

namespace oneflow {

Maybe<NaiveAsyncRpcCtx> CheckRpcToken(Symbol<RankGroup> rank_group);
Maybe<NaiveAsyncTransportCtx> CheckTransportToken(Symbol<RankGroup> rank_group);

Maybe<int64_t> GetCurrentRankGroupLevel();

Expand Down

0 comments on commit 37df98a

Please sign in to comment.