diff --git a/oneflow/api/python/framework/tensor.cpp b/oneflow/api/python/framework/tensor.cpp
index ef891e32c0f..ac8a37b4e59 100644
--- a/oneflow/api/python/framework/tensor.cpp
+++ b/oneflow/api/python/framework/tensor.cpp
@@ -250,7 +250,7 @@ void ApiRegisterTensorHook(const std::shared_ptr& self, const AutogradMe
 Maybe CheckConsistentTensorMeta(const one::Tensor& tensor, int64_t seconds) {
   const auto& ctx = JUST(LaunchTensorMetaConsistencyCheck(tensor));
-  JUST(RpcUtil::WaitUntilDoneOrTimeout(*ctx, seconds));
+  JUST(TransportUtil::WaitUntilDoneOrTimeout(*ctx, seconds));
   JUST(ctx->Check());
   return Maybe::Ok();
 }
@@ -412,9 +412,9 @@ ONEFLOW_API_PYBIND11_MODULE("", m) {
       .def_property_readonly("_tensor_buffer_shapes_and_dtypes", &GetTensorBufferShapesAndDTypes)
       .def_property_readonly("device", &TensorGetDevice)
       .def_property_readonly("data", &Tensor::data)
-      .def("rpc_token",
+      .def("consistent_id",
           [](const one::Tensor& tensor) -> int64_t {
-            return static_cast(tensor.rpc_token().GetOrThrow());
+            return static_cast(tensor.transport_token().GetOrThrow());
           })
       .def("check_meta_consistency",
           [](const one::Tensor& tensor) {
diff --git a/oneflow/api/python/rpc/consistent_rpc_token_scope.cpp b/oneflow/api/python/rpc/consistent_rpc_token_scope.cpp
index 72ea1da4020..22451b2e1d0 100644
--- a/oneflow/api/python/rpc/consistent_rpc_token_scope.cpp
+++ b/oneflow/api/python/rpc/consistent_rpc_token_scope.cpp
@@ -29,9 +29,9 @@ namespace oneflow {
 
 namespace {
 
-Maybe InitConsistentRpcTokenScope(const std::string& thread_tag,
-                                  int64_t thread_consistent_uid,
-                                  Symbol rank_group) {
+Maybe InitConsistentTransportTokenScope(const std::string& thread_tag,
+                                        int64_t thread_consistent_uid,
+                                        Symbol rank_group) {
   JUST(SetThisThreadConsistentUniqueId(thread_consistent_uid, thread_tag));
   static thread_local const auto& init_rank_group_scope =
       JUST(RankGroupScope::MakeInitialRankGroupScope(rank_group));
@@ -40,21 +40,21 @@ Maybe InitConsistentRpcTokenScope(const std::string& thread_tag,
   return Maybe::Ok();
 }
 
-Maybe InitConsistentRpcTokenScope(const std::string& thread_tag,
-                                  int64_t thread_consistent_uid) {
+Maybe InitConsistentTransportTokenScope(const std::string& thread_tag,
+                                        int64_t thread_consistent_uid) {
   const auto& rank_group = JUST(RankGroup::DefaultRankGroup());
-  JUST(InitConsistentRpcTokenScope(thread_tag, thread_consistent_uid, rank_group));
+  JUST(InitConsistentTransportTokenScope(thread_tag, thread_consistent_uid, rank_group));
   return Maybe::Ok();
 }
 
-void ApiInitDefaultConsistentRpcTokenScope() {
-  return InitConsistentRpcTokenScope("main", 0).GetOrThrow();
+void ApiInitDefaultConsistentTransportTokenScope() {
+  return InitConsistentTransportTokenScope("main", 0).GetOrThrow();
 }
 
 }  // namespace
 
 ONEFLOW_API_PYBIND11_MODULE("", m) {
-  m.def("InitDefaultConsistentRpcTokenScope", &ApiInitDefaultConsistentRpcTokenScope);
+  m.def("InitDefaultConsistentTransportTokenScope", &ApiInitDefaultConsistentTransportTokenScope);
 }
 
 }  // namespace oneflow
diff --git a/oneflow/api/python/rpc/rank_group.cpp b/oneflow/api/python/rpc/rank_group.cpp
index e7414e45c64..cf3df692107 100644
--- a/oneflow/api/python/rpc/rank_group.cpp
+++ b/oneflow/api/python/rpc/rank_group.cpp
@@ -30,8 +30,8 @@ namespace {
 
 Maybe CheckCurrentRankGroupConsistency(int64_t seconds) {
   const auto& rank_group = JUST(RankGroupScope::CurrentRankGroup());
-  const auto& ctx = JUST(CheckRpcToken(rank_group));
-  JUST(RpcUtil::WaitUntilDoneOrTimeout(*ctx, seconds));
+  const auto& ctx = JUST(CheckTransportToken(rank_group));
+  JUST(TransportUtil::WaitUntilDoneOrTimeout(*ctx, seconds));
   return Maybe::Ok();
 }
diff --git a/oneflow/core/framework/op_interpreter/eager_consistent_op_interpreter.cpp b/oneflow/core/framework/op_interpreter/eager_consistent_op_interpreter.cpp
index e0bbec01d22..8f2eacac279 100644
--- a/oneflow/core/framework/op_interpreter/eager_consistent_op_interpreter.cpp
+++ b/oneflow/core/framework/op_interpreter/eager_consistent_op_interpreter.cpp
@@ -81,8 +81,8 @@ Maybe Interpret(const UserOpExpr& user_op_expr, const TensorTuple& inputs,
   for (int i = 0; i < outputs->size(); ++i) {
     const auto& tensor_impl = JUST(EagerConsistentTensorImpl::New(output_tensor_metas.at(i),
                                                                   device, parallel_id, false, false));
-    const auto& rpc_token = JUST(RpcToken::NewMetaRpcToken());
-    JUST(tensor_impl->set_rpc_token(rpc_token));
+    const auto& transport_token = JUST(TransportToken::NewMetaTransportToken());
+    JUST(tensor_impl->set_transport_token(transport_token));
     outputs->at(i).reset(new ConsistentTensor(tensor_impl));
   }
   // Do nothing if the `parallel_desc` doesn't cover current ProcessCtx.
diff --git a/oneflow/core/framework/op_interpreter/eager_mirrored_op_interpreter.cpp b/oneflow/core/framework/op_interpreter/eager_mirrored_op_interpreter.cpp
index a37614c900f..5987011620c 100644
--- a/oneflow/core/framework/op_interpreter/eager_mirrored_op_interpreter.cpp
+++ b/oneflow/core/framework/op_interpreter/eager_mirrored_op_interpreter.cpp
@@ -259,8 +259,8 @@ Maybe EagerMirroredInterpreter::ApplyImpl(const CastToConsistentOpExpr& op
     const auto& consistent_tensor_impl = JUST(EagerConsistentTensorImpl::New(
         SymbolOf(tensor_meta), device, parallel_id, input_mirrored_tensor->requires_grad(),
         !input_mirrored_tensor->requires_grad()));
-    const auto& rpc_token = JUST(RpcToken::NewMetaRpcToken());
-    JUST(consistent_tensor_impl->set_rpc_token(rpc_token));
+    const auto& transport_token = JUST(TransportToken::NewMetaTransportToken());
+    JUST(consistent_tensor_impl->set_transport_token(transport_token));
     consistent_tensor = std::make_shared(consistent_tensor_impl);
     const auto& ctx = JUST(LaunchTensorMetaConsistencyCheck(*consistent_tensor));
     if (parallel_id.has_value()) {
@@ -269,7 +269,7 @@ Maybe EagerMirroredInterpreter::ApplyImpl(const CastToConsistentOpExpr& op
       consistent_tensor_impl->reset_cur_rank_phy_tensor(
          std::dynamic_pointer_cast(synced_tensor));
     }
-    JUST(RpcUtil::WaitUntilDoneOrTimeout(*ctx, RpcUtil::TimeoutSeconds()));
+    JUST(TransportUtil::WaitUntilDoneOrTimeout(*ctx, TransportUtil::TimeoutSeconds()));
     JUST(ctx->Check());
   }
   outputs->at(0) = consistent_tensor;
diff --git a/oneflow/core/framework/rank_group_rpc_util.cpp b/oneflow/core/framework/rank_group_rpc_util.cpp
index 22821206600..1786b26ab22 100644
--- a/oneflow/core/framework/rank_group_rpc_util.cpp
+++ b/oneflow/core/framework/rank_group_rpc_util.cpp
@@ -16,7 +16,7 @@ limitations under the License.
 #include 
 #include 
 #include "oneflow/core/framework/rank_group_rpc_util.h"
-#include "oneflow/core/framework/rpc_util.h"
+#include "oneflow/core/framework/transport_util.h"
 #include "oneflow/core/common/data_type.h"
 #include "oneflow/core/common/container_util.h"
 #include "oneflow/core/job/rank_group.h"
@@ -27,9 +27,9 @@
namespace oneflow { -Maybe CheckRpcToken(Symbol rank_group) { - const auto& rpc_token = - JUST(RpcToken::AcquireCtrlRpcToken(kRankGroupRpcCmdCheckRankGroupConsistency)); +Maybe CheckTransportToken(Symbol rank_group) { + const auto& transport_token = + JUST(TransportToken::AcquireCtrlTransportToken(kRankGroupCtrlCmdCheckRankGroupConsistency)); const auto& PrepareBuffer = [](void** buffer, std::size_t* size, std::function* Callback) -> Maybe { const auto& placeholder = std::make_shared(); @@ -38,9 +38,10 @@ Maybe CheckRpcToken(Symbol rank_group) { *Callback = [placeholder]() {}; return Maybe::Ok(); }; - const auto& ctx = std::make_shared(rpc_token, PrepareBuffer, PrepareBuffer); - JUST(RpcUtil::SendToNextRankInRing(rank_group, rpc_token, ctx.get())); - JUST(RpcUtil::ReceiveFromPrevRankInRing(rank_group, rpc_token, ctx.get())); + const auto& ctx = + std::make_shared(transport_token, PrepareBuffer, PrepareBuffer); + JUST(TransportUtil::SendToNextRankInRing(rank_group, transport_token, ctx.get())); + JUST(TransportUtil::ReceiveFromPrevRankInRing(rank_group, transport_token, ctx.get())); return ctx; } diff --git a/oneflow/core/framework/rank_group_rpc_util.h b/oneflow/core/framework/rank_group_rpc_util.h index 39f75c3ac79..24305189766 100644 --- a/oneflow/core/framework/rank_group_rpc_util.h +++ b/oneflow/core/framework/rank_group_rpc_util.h @@ -16,14 +16,14 @@ limitations under the License. #ifndef ONEFLOW_CORE_FRAMEWORK_PLACEMENT_RPC_UTIL_H_ #define ONEFLOW_CORE_FRAMEWORK_PLACEMENT_RPC_UTIL_H_ -#include "oneflow/core/framework/rpc_token.h" -#include "oneflow/core/framework/rpc_util.h" +#include "oneflow/core/framework/transport_token.h" +#include "oneflow/core/framework/transport_util.h" #include "oneflow/core/common/symbol.h" #include "oneflow/core/job/rank_group.h" namespace oneflow { -Maybe CheckRpcToken(Symbol rank_group); +Maybe CheckTransportToken(Symbol rank_group); Maybe GetCurrentRankGroupLevel(); diff --git a/oneflow/core/framework/rpc_token.cpp b/oneflow/core/framework/rpc_token.cpp deleted file mode 100644 index 6798c660ba1..00000000000 --- a/oneflow/core/framework/rpc_token.cpp +++ /dev/null @@ -1,292 +0,0 @@ -/* -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-*/ -#include -#include "oneflow/core/framework/rpc_token.h" -#include "oneflow/core/common/data_type.h" -#include "oneflow/core/common/data_type.h" -#include "oneflow/core/thread/consistent_unique_id.h" -#include "oneflow/core/framework/rank_group_rpc_util.h" - -namespace oneflow { - -namespace { - -class DataRpcTokenView final { - public: - static Maybe MutCast(RpcToken* rpc_token) { - CHECK_EQ_OR_RETURN(rpc_token->type(), kDataRpcTokenType); - return reinterpret_cast(rpc_token); - } - - void set_data_seq_id(int64_t seq_id) { data_seq_id_ = seq_id; } - - private: - uint16_t src_rank_; - uint16_t dst_rank_; - uint32_t type_ : 2; // RpcTokenType - uint32_t data_seq_id_ : 30; -}; -static_assert(sizeof(DataRpcTokenView) == sizeof(uint64_t), ""); - -class MetaRpcTokenView final { - public: - int64_t thread_consistent_unique_id() const { return thread_consistent_unique_id_; } - int64_t rank_group_level() const { return rank_group_level_; } - - static Maybe MutCast(RpcToken* rpc_token) { - CHECK_EQ_OR_RETURN(rpc_token->type(), kMetaRpcTokenType); - return reinterpret_cast(rpc_token); - } - - static Maybe Cast(const RpcToken* rpc_token) { - CHECK_EQ_OR_RETURN(rpc_token->type(), kMetaRpcTokenType); - return reinterpret_cast(rpc_token); - } - - Maybe set_thread_consistent_unique_id(int8_t val) { - CHECK_GE_OR_RETURN(val, 0); - CHECK_LT_OR_RETURN(val, 1 << kRpcTokenThreadConsistentUIdBit); - thread_consistent_unique_id_ = val; - return Maybe::Ok(); - } - - Maybe set_rank_group_level(int32_t val) { - CHECK_GE_OR_RETURN(val, 0); - CHECK_LT_OR_RETURN(val, 1 << kRpcTokenRankGroupLevelBit); - rank_group_level_ = val; - return Maybe::Ok(); - } - - MetaRpcTokenView& operator++() { - ++low_meta_seq_id_; - if (low_meta_seq_id_ == 0) { ++high_meta_seq_id_; } - return *this; - } - - private: - uint16_t src_rank_; - uint16_t dst_rank_; - uint8_t type_ : 2; // RpcTokenType - uint8_t thread_consistent_unique_id_ : kRpcTokenThreadConsistentUIdBit; - uint8_t rank_group_level_ : kRpcTokenRankGroupLevelBit; - uint8_t high_meta_seq_id_; - uint16_t low_meta_seq_id_; -}; -static_assert(sizeof(MetaRpcTokenView) == sizeof(uint64_t), ""); - -class CtrlRpcTokenView final { - public: - int64_t thread_consistent_unique_id() const { return thread_consistent_unique_id_; } - int64_t rank_group_level() const { return rank_group_level_; } - - static Maybe MutCast(RpcToken* rpc_token) { - CHECK_EQ_OR_RETURN(rpc_token->type(), kCtrlRpcTokenType); - return reinterpret_cast(rpc_token); - } - - static Maybe Cast(const RpcToken* rpc_token) { - CHECK_EQ_OR_RETURN(rpc_token->type(), kCtrlRpcTokenType); - return reinterpret_cast(rpc_token); - } - - Maybe set_thread_consistent_unique_id(int8_t val) { - CHECK_GE_OR_RETURN(val, 0); - CHECK_LT_OR_RETURN(val, 1 << kRpcTokenThreadConsistentUIdBit); - thread_consistent_unique_id_ = val; - return Maybe::Ok(); - } - Maybe set_rank_group_level(int32_t val) { - CHECK_GE_OR_RETURN(val, 0); - CHECK_LT_OR_RETURN(val, 1 << kRpcTokenRankGroupLevelBit); - rank_group_level_ = val; - return Maybe::Ok(); - } - - RankGroupRpcCmd cmd() const { return static_cast(cmd_); } - - void set_cmd(RankGroupRpcCmd cmd) { - static_assert(kSizeOfRankGroupRpcCmd < (1 << 8), ""); - cmd_ = static_cast(cmd); - } - - void set_ctrl_seq_id(int32_t val) { ctrl_seq_id_ = val; } - - private: - uint16_t src_rank_; - uint16_t dst_rank_; - uint8_t type_ : 2; // RpcTokenType - uint8_t thread_consistent_unique_id_ : kRpcTokenThreadConsistentUIdBit; - uint8_t rank_group_level_ : kRpcTokenRankGroupLevelBit; - uint8_t cmd_; - uint16_t 
ctrl_seq_id_; -}; -static_assert(sizeof(CtrlRpcTokenView) == sizeof(uint64_t), ""); - -} // namespace - -RpcToken::RpcToken(RpcTokenType type) { - static_assert(sizeof(RpcToken) == sizeof(int64_t), ""); - *reinterpret_cast(this) = 0; - type_ = type; -} - -/*static*/ RpcToken RpcToken::NewDataRpcToken() { - static auto* seq_id = new std::atomic(); - RpcToken rpc_token(kDataRpcTokenType); - CHECK_JUST(DataRpcTokenView::MutCast(&rpc_token))->set_data_seq_id(++*seq_id); - return rpc_token; -} - -/*static*/ Maybe RpcToken::NewMetaRpcToken() { - int32_t thread_consistent_unique_id = JUST(GetThisThreadConsistentUniqueId()); - int32_t rank_group_level = JUST(GetCurrentRankGroupLevel()); - static const int kLimit = 128; - CHECK_GE_OR_RETURN(rank_group_level, 0); - CHECK_LT_OR_RETURN(rank_group_level, kLimit); - static thread_local std::array, kLimit> rpc_token_stack; - auto* current_rpc_token = &rpc_token_stack[rank_group_level]; - if (!*current_rpc_token) { - const auto& init = JUST(NewMetaRpcToken(thread_consistent_unique_id, rank_group_level)); - current_rpc_token->reset(new RpcToken(init)); - } - return ++**current_rpc_token; -} - -namespace { - -Maybe ThreadLocalMutLock4CtrlRpcToken(int32_t thread_consistent_unique_id, - int32_t rank_group_level, RankGroupRpcCmd cmd) { - CHECK_EQ_OR_RETURN(thread_consistent_unique_id, JUST(GetThisThreadConsistentUniqueId())); - static const int kRpcTokenRankGroupLevelLimit = (1 << kRpcTokenRankGroupLevelBit); - CHECK_LT_OR_RETURN(rank_group_level, kRpcTokenRankGroupLevelLimit); - static thread_local std::array, - kRpcTokenRankGroupLevelLimit> - rpc_token_lock; - return &rpc_token_lock[rank_group_level][cmd]; -} - -} // namespace - -/*static*/ Maybe RpcToken::AcquireCtrlRpcToken(RankGroupRpcCmd cmd) { - int32_t thread_consistent_unique_id = JUST(GetThisThreadConsistentUniqueId()); - int32_t rank_group_level = JUST(GetCurrentRankGroupLevel()); - auto* lock = - JUST(ThreadLocalMutLock4CtrlRpcToken(thread_consistent_unique_id, rank_group_level, cmd)); - CHECK_OR_RETURN(!*lock); - static const int kRpcTokenRankGroupLevelLimit = (1 << kRpcTokenRankGroupLevelBit); - static thread_local std::array, kSizeOfRankGroupRpcCmd>, - kRpcTokenRankGroupLevelLimit> - rpc_token_stack; - CHECK_GE_OR_RETURN(rank_group_level, 0); - CHECK_LT_OR_RETURN(rank_group_level, kRpcTokenRankGroupLevelLimit); - CHECK_GE_OR_RETURN(static_cast(cmd), 0); - CHECK_LT_OR_RETURN(static_cast(cmd), kSizeOfRankGroupRpcCmd); - auto* current_rpc_token = &rpc_token_stack[rank_group_level][cmd]; - if (!*current_rpc_token) { - const auto& init = JUST(NewCtrlRpcToken(cmd, thread_consistent_unique_id, rank_group_level)); - current_rpc_token->reset(new RpcToken(init)); - } - *lock = true; - return **current_rpc_token; -} - -Maybe RpcToken::ReleaseCtrlRpcToken() const { - auto* lock = JUST(ThreadLocalMutLock4CtrlRpcToken(JUST(thread_consistent_unique_id()), - JUST(rank_group_level()), JUST(cmd()))); - CHECK_OR_RETURN(*lock); - *lock = false; - return Maybe::Ok(); -} - -Maybe RpcToken::thread_consistent_unique_id() const { - if (type() == kMetaRpcTokenType) { - return JUST(MetaRpcTokenView::Cast(this))->thread_consistent_unique_id(); - } else if (type() == kCtrlRpcTokenType) { - return JUST(CtrlRpcTokenView::Cast(this))->thread_consistent_unique_id(); - } else { - UNIMPLEMENTED_THEN_RETURN(); - } - UNIMPLEMENTED_THEN_RETURN(); -} - -Maybe RpcToken::rank_group_level() const { - if (type() == kMetaRpcTokenType) { - return JUST(MetaRpcTokenView::Cast(this))->rank_group_level(); - } else if (type() == 
kCtrlRpcTokenType) { - return JUST(CtrlRpcTokenView::Cast(this))->rank_group_level(); - } else { - UNIMPLEMENTED_THEN_RETURN(); - } - UNIMPLEMENTED_THEN_RETURN(); -} - -Maybe RpcToken::cmd() const { return JUST(CtrlRpcTokenView::Cast(this))->cmd(); } - -Maybe RpcToken::set_src_rank(int64_t src_rank) { - CHECK_GE_OR_RETURN(src_rank, 0); - CHECK_LT_OR_RETURN(src_rank, GetMaxVal()); - src_rank_ = src_rank; - return Maybe::Ok(); -} - -Maybe RpcToken::set_dst_rank(int64_t dst_rank) { - CHECK_GE_OR_RETURN(dst_rank, 0); - CHECK_LT_OR_RETURN(dst_rank, GetMaxVal()); - dst_rank_ = dst_rank; - return Maybe::Ok(); -} - -RpcToken::operator uint64_t() const { - static_assert(sizeof(RpcToken) == sizeof(uint64_t), ""); - return *reinterpret_cast(this); -} - -RpcToken& RpcToken::operator++() { - RpcTokenType rpc_token_type = type(); - if (rpc_token_type == kDataRpcTokenType) { - UNIMPLEMENTED(); - } else if (rpc_token_type == kMetaRpcTokenType) { - ++*CHECK_JUST(MetaRpcTokenView::MutCast(this)); - } else if (rpc_token_type == kCtrlRpcTokenType) { - UNIMPLEMENTED(); - } else { - UNIMPLEMENTED(); - } - return *this; -} - -/*static*/ Maybe RpcToken::NewMetaRpcToken(int32_t thread_consistent_unique_id, - int32_t rank_group_level) { - RpcToken rpc_token(kMetaRpcTokenType); - auto* view = JUST(MetaRpcTokenView::MutCast(&rpc_token)); - JUST(view->set_thread_consistent_unique_id(thread_consistent_unique_id)); - JUST(view->set_rank_group_level(rank_group_level)); - return rpc_token; -} - -/*static*/ Maybe RpcToken::NewCtrlRpcToken(RankGroupRpcCmd cmd, - int32_t thread_consistent_unique_id, - int32_t rank_group_level) { - RpcToken rpc_token(kCtrlRpcTokenType); - auto* view = JUST(CtrlRpcTokenView::MutCast(&rpc_token)); - JUST(view->set_thread_consistent_unique_id(thread_consistent_unique_id)); - JUST(view->set_rank_group_level(rank_group_level)); - view->set_cmd(cmd); - view->set_ctrl_seq_id(0); - return rpc_token; -} - -} // namespace oneflow diff --git a/oneflow/core/framework/rpc_token.h b/oneflow/core/framework/rpc_token.h deleted file mode 100644 index e54588c1ee2..00000000000 --- a/oneflow/core/framework/rpc_token.h +++ /dev/null @@ -1,107 +0,0 @@ -/* -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ -#ifndef ONEFLOW_CORE_FRAMEWORK_RPC_TOKEN_H_ -#define ONEFLOW_CORE_FRAMEWORK_RPC_TOKEN_H_ - -#include "oneflow/core/common/type_traits.h" -#include "oneflow/core/common/maybe.h" - -namespace oneflow { - -const static int kRpcTokenTypeBit = 2; -const static int kRpcTokenThreadConsistentUIdBit = 3; -const static int kRpcTokenRankGroupLevelBit = 3; - -enum RpcTokenType { - // Begin - kDataRpcTokenType = 0, // e.g. for tensor data transportation - kMetaRpcTokenType, // e.g. for tensor meta checking - kCtrlRpcTokenType, // e.g. for rank_group or thread checking. 
see RankGroupRpcCmd - kExtendedRpcTokenType, // for compatibility - // End - kRpcTokenTypeSize, -}; - -static_assert(kRpcTokenTypeSize <= (1 << kRpcTokenTypeBit), ""); - -enum RankGroupRpcCmd { - // Begin - kRankGroupRpcCmdInvalid = 0, - kRankGroupRpcCmdSyncSymbolParallelDesc, - kRankGroupRpcCmdSyncSymbolParallelDistribution, - kRankGroupRpcCmdSyncSymbolConsistentTensorMeta, - kRankGroupRpcCmdCheckRankGroupConsistency, - kRankGroupRpcCmdCheckTensorConsistency, - kRankGroupRpcCmdAll2AllSyncShape, - // End - kSizeOfRankGroupRpcCmd -}; - -class RpcToken; - -template<> -struct IsScalarType final { - static const bool value = true; -}; - -class RpcToken final { - public: - RpcToken(const RpcToken&) = default; - RpcToken(RpcToken&) = default; - ~RpcToken() = default; - - static RpcToken NewDataRpcToken(); - static Maybe NewMetaRpcToken(); - static Maybe AcquireCtrlRpcToken(RankGroupRpcCmd cmd); - Maybe ReleaseCtrlRpcToken() const; - - static constexpr size_t MaxNumberOfThreadConsistentUId() { - return (1 << kRpcTokenThreadConsistentUIdBit); - } - - // Getters - int64_t src_rank() const { return src_rank_; } - int64_t dst_rank() const { return dst_rank_; } - RpcTokenType type() const { return static_cast(type_); } - Maybe thread_consistent_unique_id() const; - Maybe rank_group_level() const; - Maybe cmd() const; - - // Setters - Maybe set_src_rank(int64_t src_rank); - Maybe set_dst_rank(int64_t dst_rank); - - operator uint64_t() const; - RpcToken& operator++(); - - private: - explicit RpcToken(RpcTokenType type); - - static Maybe NewMetaRpcToken(int32_t thread_consistent_unique_id, - int32_t rank_group_level); - static Maybe NewCtrlRpcToken(RankGroupRpcCmd cmd, int32_t thread_consistent_unique_id, - int32_t rank_group_level); - - uint16_t src_rank_; - uint16_t dst_rank_; - uint32_t type_ : 2; // RpcTokenType - uint32_t opaque_ids_ : 30; -}; -static_assert(sizeof(RpcToken) == sizeof(uint64_t), ""); - -} // namespace oneflow - -#endif // ONEFLOW_CORE_FRAMEWORK_RPC_TOKEN_H_ diff --git a/oneflow/core/framework/sync_symbol_consistent_tensor_meta.cpp b/oneflow/core/framework/sync_symbol_consistent_tensor_meta.cpp index b808484d8bc..d3887fcac14 100644 --- a/oneflow/core/framework/sync_symbol_consistent_tensor_meta.cpp +++ b/oneflow/core/framework/sync_symbol_consistent_tensor_meta.cpp @@ -68,11 +68,11 @@ struct FlatConsistentTensorMeta final { Maybe SyncSymbolConsistentTensorMeta( uint64_t symbol_id, Symbol consistent_tensor_meta) { - const auto& rpc_token = - JUST(RpcToken::AcquireCtrlRpcToken(kRankGroupRpcCmdSyncSymbolConsistentTensorMeta)); + const auto& transport_token = JUST( + TransportToken::AcquireCtrlTransportToken(kRankGroupCtrlCmdSyncSymbolConsistentTensorMeta)); const auto& recv_buffer = std::make_shared(); - NaiveAsyncRpcCtx ctx( - rpc_token, + NaiveAsyncTransportCtx ctx( + transport_token, [&](void** buffer, std::size_t* size, std::function* Cb) -> Maybe { const auto& send_buffer = JUST(FlatConsistentTensorMeta::New(symbol_id, consistent_tensor_meta)); @@ -88,9 +88,9 @@ Maybe SyncSymbolConsistentTensorMeta( return Maybe::Ok(); }); const auto& rank_group = JUST(RankGroupScope::CurrentRankGroup()); - JUST(RpcUtil::SendToNextRankInRing(rank_group, rpc_token, &ctx)); - JUST(RpcUtil::ReceiveFromPrevRankInRing(rank_group, rpc_token, &ctx)); - JUST(RpcUtil::WaitUntilDoneOrTimeout(ctx, RpcUtil::TimeoutSeconds())); + JUST(TransportUtil::SendToNextRankInRing(rank_group, transport_token, &ctx)); + JUST(TransportUtil::ReceiveFromPrevRankInRing(rank_group, transport_token, &ctx)); + 
JUST(TransportUtil::WaitUntilDoneOrTimeout(ctx, TransportUtil::TimeoutSeconds())); JUST(recv_buffer->Check(symbol_id, consistent_tensor_meta)); return Maybe::Ok(); } diff --git a/oneflow/core/framework/sync_symbol_consistent_tensor_meta.h b/oneflow/core/framework/sync_symbol_consistent_tensor_meta.h index be61e3dd0b9..16b4f998824 100644 --- a/oneflow/core/framework/sync_symbol_consistent_tensor_meta.h +++ b/oneflow/core/framework/sync_symbol_consistent_tensor_meta.h @@ -18,8 +18,8 @@ limitations under the License. #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/symbol.h" -#include "oneflow/core/framework/rpc_util.h" -#include "oneflow/core/framework/rpc_token.h" +#include "oneflow/core/framework/transport_util.h" +#include "oneflow/core/framework/transport_token.h" namespace oneflow { diff --git a/oneflow/core/framework/sync_symbol_parallel_desc.cpp b/oneflow/core/framework/sync_symbol_parallel_desc.cpp index 05f8a57154c..962454ac29c 100644 --- a/oneflow/core/framework/sync_symbol_parallel_desc.cpp +++ b/oneflow/core/framework/sync_symbol_parallel_desc.cpp @@ -68,11 +68,11 @@ struct FlatParallelConf { } // namespace Maybe SyncSymbolParallelDesc(uint64_t symbol_id, Symbol parallel_desc) { - const auto& rpc_token = - JUST(RpcToken::AcquireCtrlRpcToken(kRankGroupRpcCmdSyncSymbolParallelDesc)); + const auto& transport_token = + JUST(TransportToken::AcquireCtrlTransportToken(kRankGroupCtrlCmdSyncSymbolParallelDesc)); const auto& recv_buffer = std::make_shared(); - NaiveAsyncRpcCtx ctx( - rpc_token, + NaiveAsyncTransportCtx ctx( + transport_token, [&](void** buffer, std::size_t* size, std::function* Cb) -> Maybe { const auto& send_buffer = JUST(FlatParallelConf::New(symbol_id, parallel_desc)); *buffer = send_buffer.get(); @@ -87,9 +87,9 @@ Maybe SyncSymbolParallelDesc(uint64_t symbol_id, Symbol para return Maybe::Ok(); }); const auto& rank_group = JUST(RankGroupScope::CurrentRankGroup()); - JUST(RpcUtil::SendToNextRankInRing(rank_group, rpc_token, &ctx)); - JUST(RpcUtil::ReceiveFromPrevRankInRing(rank_group, rpc_token, &ctx)); - JUST(RpcUtil::WaitUntilDoneOrTimeout(ctx, RpcUtil::TimeoutSeconds())); + JUST(TransportUtil::SendToNextRankInRing(rank_group, transport_token, &ctx)); + JUST(TransportUtil::ReceiveFromPrevRankInRing(rank_group, transport_token, &ctx)); + JUST(TransportUtil::WaitUntilDoneOrTimeout(ctx, TransportUtil::TimeoutSeconds())); JUST(recv_buffer->Check(symbol_id, parallel_desc)); return Maybe::Ok(); } diff --git a/oneflow/core/framework/sync_symbol_parallel_desc.h b/oneflow/core/framework/sync_symbol_parallel_desc.h index 02309e7b30c..7b6a6fc97ed 100644 --- a/oneflow/core/framework/sync_symbol_parallel_desc.h +++ b/oneflow/core/framework/sync_symbol_parallel_desc.h @@ -18,8 +18,8 @@ limitations under the License. 
#include "oneflow/core/common/maybe.h" #include "oneflow/core/common/symbol.h" -#include "oneflow/core/framework/rpc_util.h" -#include "oneflow/core/framework/rpc_token.h" +#include "oneflow/core/framework/transport_util.h" +#include "oneflow/core/framework/transport_token.h" namespace oneflow { diff --git a/oneflow/core/framework/sync_symbol_parallel_distribution.cpp b/oneflow/core/framework/sync_symbol_parallel_distribution.cpp index 7268991ef4c..0b9d1850c9c 100644 --- a/oneflow/core/framework/sync_symbol_parallel_distribution.cpp +++ b/oneflow/core/framework/sync_symbol_parallel_distribution.cpp @@ -94,15 +94,16 @@ FLAT_MSG_DEFINE_OPTIONAL(size_t, size); FLAT_MSG_DEFINE_REPEATED(FlatSbpParallel, sbp_parallel, SHAPE_MAX_AXIS_SIZE); FLAT_MSG_END(FlatParallelDistribution); -class FlatParallelDistributionAsyncRpcCtx : public AsyncRpcCtx { +class FlatParallelDistributionAsyncTransportCtx : public AsyncTransportCtx { public: - FlatParallelDistributionAsyncRpcCtx(const RpcToken& rpc_token, uint64_t symbol_id, - Symbol parallel_distribution) - : AsyncRpcCtx(rpc_token), + FlatParallelDistributionAsyncTransportCtx(const TransportToken& transport_token, + uint64_t symbol_id, + Symbol parallel_distribution) + : AsyncTransportCtx(transport_token), symbol_id_(symbol_id), parallel_distribution_(parallel_distribution) {} - ~FlatParallelDistributionAsyncRpcCtx() override {} + ~FlatParallelDistributionAsyncTransportCtx() override {} Maybe PrepareSendBufferAndCallback(int64_t rank, void** buffer, std::size_t* size, std::function* Callback) override { @@ -143,12 +144,12 @@ namespace {} Maybe SyncSymbolParallelDistribution(uint64_t symbol_id, Symbol symbol) { const auto& rank_group = JUST(RankGroupScope::CurrentRankGroup()); - const auto& rpc_token = - JUST(RpcToken::AcquireCtrlRpcToken(kRankGroupRpcCmdSyncSymbolParallelDistribution)); - FlatParallelDistributionAsyncRpcCtx ctx(rpc_token, symbol_id, symbol); - JUST(RpcUtil::SendToNextRankInRing(rank_group, rpc_token, &ctx)); - JUST(RpcUtil::ReceiveFromPrevRankInRing(rank_group, rpc_token, &ctx)); - JUST(RpcUtil::WaitUntilDoneOrTimeout(ctx, RpcUtil::TimeoutSeconds())); + const auto& transport_token = JUST( + TransportToken::AcquireCtrlTransportToken(kRankGroupCtrlCmdSyncSymbolParallelDistribution)); + FlatParallelDistributionAsyncTransportCtx ctx(transport_token, symbol_id, symbol); + JUST(TransportUtil::SendToNextRankInRing(rank_group, transport_token, &ctx)); + JUST(TransportUtil::ReceiveFromPrevRankInRing(rank_group, transport_token, &ctx)); + JUST(TransportUtil::WaitUntilDoneOrTimeout(ctx, TransportUtil::TimeoutSeconds())); JUST(ctx.Check()); return Maybe::Ok(); } diff --git a/oneflow/core/framework/sync_symbol_parallel_distribution.h b/oneflow/core/framework/sync_symbol_parallel_distribution.h index d14cce93b87..2990b55da25 100644 --- a/oneflow/core/framework/sync_symbol_parallel_distribution.h +++ b/oneflow/core/framework/sync_symbol_parallel_distribution.h @@ -18,8 +18,8 @@ limitations under the License. #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/symbol.h" -#include "oneflow/core/framework/rpc_util.h" -#include "oneflow/core/framework/rpc_token.h" +#include "oneflow/core/framework/transport_util.h" +#include "oneflow/core/framework/transport_token.h" namespace oneflow { diff --git a/oneflow/core/framework/tensor.h b/oneflow/core/framework/tensor.h index 4c3a8244680..d2da1e67932 100644 --- a/oneflow/core/framework/tensor.h +++ b/oneflow/core/framework/tensor.h @@ -22,7 +22,7 @@ limitations under the License. 
#include "oneflow/core/common/shape.h" #include "oneflow/core/memory/memory_case.pb.h" #include "oneflow/core/framework/tensor_impl.h" -#include "oneflow/core/framework/rpc_token.h" +#include "oneflow/core/framework/transport_token.h" #include "oneflow/core/common/error.h" namespace oneflow { @@ -50,7 +50,7 @@ class Tensor { virtual const std::shared_ptr& shape() const = 0; virtual DataType dtype() const = 0; - virtual Maybe rpc_token() const = 0; + virtual Maybe transport_token() const = 0; virtual Maybe> parallel_distribution() const = 0; virtual Maybe> parallel_desc() const = 0; virtual Maybe> device() const = 0; @@ -123,7 +123,7 @@ class StaticZerosTensor final : public Tensor { // Getters const std::shared_ptr& shape() const { return shape_; } DataType dtype() const { return dtype_; } - Maybe rpc_token() const { OF_UNIMPLEMENTED(); } + Maybe transport_token() const { OF_UNIMPLEMENTED(); } Maybe> parallel_distribution() const { OF_UNIMPLEMENTED(); } Maybe> parallel_desc() const { OF_UNIMPLEMENTED(); } Maybe> device() const { return device_; } @@ -296,7 +296,7 @@ class Parameter final : public TensorIf { consumer_parallel_distribution_constraint() const override { return tensor_->consumer_parallel_distribution_constraint(); } - Maybe rpc_token() const override { return tensor_->rpc_token(); } + Maybe transport_token() const override { return tensor_->transport_token(); } Maybe cur_rank_phy_tensor() const override { return tensor_->cur_rank_phy_tensor(); } @@ -357,7 +357,7 @@ class MirroredTensor final : public TensorIf, // Getters const std::shared_ptr& shape() const override { return impl_->shape(); } DataType dtype() const override { return impl_->dtype(); } - Maybe rpc_token() const override { OF_UNIMPLEMENTED(); } + Maybe transport_token() const override { OF_UNIMPLEMENTED(); } Maybe> parallel_distribution() const override { OF_UNIMPLEMENTED(); } @@ -438,7 +438,7 @@ class ConsistentTensor final : public TensorIf { // Getters const std::shared_ptr& shape() const override { return impl_->shape(); } DataType dtype() const override { return impl_->dtype(); } - Maybe rpc_token() const override { return impl_->rpc_token(); } + Maybe transport_token() const override { return impl_->transport_token(); } Maybe> parallel_distribution() const override { return impl_->parallel_distribution(); } diff --git a/oneflow/core/framework/tensor_impl.h b/oneflow/core/framework/tensor_impl.h index 5b97a6d0901..e5883fc10bf 100644 --- a/oneflow/core/framework/tensor_impl.h +++ b/oneflow/core/framework/tensor_impl.h @@ -25,7 +25,7 @@ limitations under the License. 
#include "oneflow/core/framework/tensor_storage.h" #include "oneflow/core/framework/tensor_desc.h" #include "oneflow/core/framework/tensor_meta.h" -#include "oneflow/core/framework/rpc_token.h" +#include "oneflow/core/framework/transport_token.h" #include "oneflow/core/autograd/autograd_meta.h" #include "oneflow/core/common/symbol.h" @@ -158,11 +158,11 @@ class ConsistentTensorImpl : public TensorImpl { return nullptr; } - const Maybe rpc_token() const { return rpc_token_; } + const Maybe transport_token() const { return transport_token_; } - Maybe set_rpc_token(const RpcToken& rpc_token) { - CHECK_OR_RETURN(!rpc_token_.IsOk()) << "rpc_token_ is initiliazed"; - rpc_token_ = rpc_token; + Maybe set_transport_token(const TransportToken& transport_token) { + CHECK_OR_RETURN(!transport_token_.IsOk()) << "transport_token_ is initiliazed"; + transport_token_ = transport_token; return Maybe::Ok(); } @@ -171,11 +171,11 @@ class ConsistentTensorImpl : public TensorImpl { : TensorImpl(requires_grad, is_leaf), tensor_meta_(tensor_meta), consumer_parallel_distribution_constraint_(), - rpc_token_(Error::ValueError("invalid rpc token")) {} + transport_token_(Error::ValueError("invalid rpc token")) {} Symbol tensor_meta_; Optional> consumer_parallel_distribution_constraint_; - Maybe rpc_token_; + Maybe transport_token_; }; class LazyMirroredTensorImpl final : public MirroredTensorImpl { diff --git a/oneflow/core/framework/tensor_rpc_util.cpp b/oneflow/core/framework/tensor_rpc_util.cpp index eccd2f44618..99e1d15817f 100644 --- a/oneflow/core/framework/tensor_rpc_util.cpp +++ b/oneflow/core/framework/tensor_rpc_util.cpp @@ -38,16 +38,16 @@ FLAT_MSG_BEGIN(FlatTensorConsistency); OF_PUBLIC static Maybe New( Symbol tensor_meta, const Optional> consumer_parallel_distribution_constraint, - const RpcToken& tensor_rpc_token) { + const TransportToken& tensor_transport_token) { const auto& consistency = std::make_shared(); consistency->clear(); - JUST(consistency->Init(tensor_meta, consumer_parallel_distribution_constraint, tensor_rpc_token)); + JUST(consistency->Init(tensor_meta, consumer_parallel_distribution_constraint, tensor_transport_token)); return consistency; } OF_PUBLIC Maybe Check(Symbol tensor_meta, const Optional> consumer_parallel_distribution_constraint, - const RpcToken& tensor_rpc_token) { + const TransportToken& tensor_transport_token) { const auto& this_synced_tensor_meta = JUST(SyncedSymbolMap::Symbol4SyncedSymbolId( this->synced_tensor_meta_symbol_id())); @@ -61,13 +61,13 @@ FLAT_MSG_BEGIN(FlatTensorConsistency); const auto& this_rank_constaint = JUST(consumer_parallel_distribution_constraint.value()); CHECK_OR_RETURN(this_rank_constaint == that_rank_constaint); } - CHECK_EQ_OR_RETURN(this->tensor_rpc_token(), tensor_rpc_token); + CHECK_EQ_OR_RETURN(this->tensor_transport_token(), tensor_transport_token); return Maybe::Ok(); } OF_PRIVATE Maybe Init(Symbol tensor_meta, const Optional> consumer_parallel_distribution_constraint, - const RpcToken& tensor_rpc_token) { + const TransportToken& tensor_transport_token) { this->set_synced_tensor_meta_symbol_id(JUST(SyncedSymbolMap::FindOrSync( tensor_meta, &SyncSymbolConsistentTensorMeta))); if (consumer_parallel_distribution_constraint.has_value()) { @@ -78,29 +78,29 @@ FLAT_MSG_BEGIN(FlatTensorConsistency); } else { this->clear_consumer_parallel_distribution_constraint_symbol_id(); } - this->set_tensor_rpc_token(static_cast(tensor_rpc_token)); + this->set_tensor_transport_token(static_cast(tensor_transport_token)); return Maybe::Ok(); } 
FLAT_MSG_DEFINE_OPTIONAL(uint64_t, synced_tensor_meta_symbol_id); FLAT_MSG_DEFINE_OPTIONAL(uint64_t, consumer_parallel_distribution_constraint_symbol_id); - FLAT_MSG_DEFINE_OPTIONAL(uint64_t, tensor_rpc_token); + FLAT_MSG_DEFINE_OPTIONAL(uint64_t, tensor_transport_token); FLAT_MSG_END(FlatTensorConsistency); // clang-format off -CheckConsistencyAsyncRpcCtx::~CheckConsistencyAsyncRpcCtx() {} +CheckConsistencyAsyncTransportCtx::~CheckConsistencyAsyncTransportCtx() {} -Maybe CheckConsistencyAsyncRpcCtx::PrepareSendBufferAndCallback( +Maybe CheckConsistencyAsyncTransportCtx::PrepareSendBufferAndCallback( int64_t rank, void** buffer, std::size_t* size, std::function* Callback) { const auto& tensor_consistency = - JUST(FlatTensorConsistency::New(tensor_meta_, consumer_parallel_distribution_constraint_, tensor_rpc_token_)); + JUST(FlatTensorConsistency::New(tensor_meta_, consumer_parallel_distribution_constraint_, tensor_transport_token_)); *buffer = tensor_consistency.get(); *size = sizeof(FlatTensorConsistency); *Callback = [tensor_consistency] {}; return Maybe::Ok(); } -Maybe CheckConsistencyAsyncRpcCtx::PrepareRecvBufferAndCallback( +Maybe CheckConsistencyAsyncTransportCtx::PrepareRecvBufferAndCallback( int64_t rank, void** buffer, std::size_t* size, std::function* Callback) { const auto& flat_tensor_consistency = JUST(FlatTensorConsistency::New()); *buffer = flat_tensor_consistency.get(); @@ -110,24 +110,24 @@ Maybe CheckConsistencyAsyncRpcCtx::PrepareRecvBufferAndCallback( return Maybe::Ok(); } -Maybe CheckConsistencyAsyncRpcCtx::Check() const { +Maybe CheckConsistencyAsyncTransportCtx::Check() const { if (!flat_tensor_consistency_) { return Maybe::Ok(); } JUST(flat_tensor_consistency_->Check( - tensor_meta_, consumer_parallel_distribution_constraint_, tensor_rpc_token_)); + tensor_meta_, consumer_parallel_distribution_constraint_, tensor_transport_token_)); return Maybe::Ok(); } -Maybe LaunchTensorMetaConsistencyCheck(const one::Tensor& tensor) { +Maybe LaunchTensorMetaConsistencyCheck(const one::Tensor& tensor) { const auto& rank_group = JUST(RankGroupScope::CurrentRankGroup()); - const auto& rpc_token = - JUST(RpcToken::AcquireCtrlRpcToken(kRankGroupRpcCmdCheckTensorConsistency)); + const auto& transport_token = + JUST(TransportToken::AcquireCtrlTransportToken(kRankGroupCtrlCmdCheckTensorConsistency)); const auto& tensor_meta = JUST(tensor.consistent_tensor_meta()); const auto& constaint = JUST(tensor.consumer_parallel_distribution_constraint()); - const RpcToken& tensor_rpc_token = JUST(tensor.rpc_token()); - const auto& ctx = std::make_shared( - rpc_token, tensor_meta, constaint, tensor_rpc_token); - JUST(RpcUtil::SendToNextRankInRing(rank_group, rpc_token, ctx.get())); - JUST(RpcUtil::ReceiveFromPrevRankInRing(rank_group, rpc_token, ctx.get())); + const TransportToken& tensor_transport_token = JUST(tensor.transport_token()); + const auto& ctx = std::make_shared( + transport_token, tensor_meta, constaint, tensor_transport_token); + JUST(TransportUtil::SendToNextRankInRing(rank_group, transport_token, ctx.get())); + JUST(TransportUtil::ReceiveFromPrevRankInRing(rank_group, transport_token, ctx.get())); return ctx; } diff --git a/oneflow/core/framework/tensor_rpc_util.h b/oneflow/core/framework/tensor_rpc_util.h index 2e48678972f..f29c20552d1 100644 --- a/oneflow/core/framework/tensor_rpc_util.h +++ b/oneflow/core/framework/tensor_rpc_util.h @@ -16,7 +16,7 @@ limitations under the License. 
#ifndef ONEFLOW_CORE_FRAMEWORK_TENSOR_RPC_UTIL_H_ #define ONEFLOW_CORE_FRAMEWORK_TENSOR_RPC_UTIL_H_ -#include "oneflow/core/framework/rpc_util.h" +#include "oneflow/core/framework/transport_util.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/common/optional.h" @@ -24,18 +24,18 @@ namespace oneflow { class FlatTensorConsistency; -class CheckConsistencyAsyncRpcCtx : public AsyncRpcCtx { +class CheckConsistencyAsyncTransportCtx : public AsyncTransportCtx { public: - CheckConsistencyAsyncRpcCtx( - const RpcToken& rpc_token, Symbol tensor_meta, + CheckConsistencyAsyncTransportCtx( + const TransportToken& transport_token, Symbol tensor_meta, const Optional>& consumer_parallel_distribution_constraint, - const RpcToken& tensor_rpc_token) - : AsyncRpcCtx(rpc_token), + const TransportToken& tensor_transport_token) + : AsyncTransportCtx(transport_token), tensor_meta_(tensor_meta), consumer_parallel_distribution_constraint_(consumer_parallel_distribution_constraint), - tensor_rpc_token_(tensor_rpc_token) {} + tensor_transport_token_(tensor_transport_token) {} - ~CheckConsistencyAsyncRpcCtx() override; + ~CheckConsistencyAsyncTransportCtx() override; Maybe PrepareSendBufferAndCallback(int64_t rank, void** buffer, std::size_t* size, std::function* Callback) override; @@ -48,11 +48,12 @@ class CheckConsistencyAsyncRpcCtx : public AsyncRpcCtx { private: Symbol tensor_meta_; Optional> consumer_parallel_distribution_constraint_; - RpcToken tensor_rpc_token_; + TransportToken tensor_transport_token_; std::shared_ptr flat_tensor_consistency_; }; -Maybe LaunchTensorMetaConsistencyCheck(const one::Tensor& tensor); +Maybe LaunchTensorMetaConsistencyCheck( + const one::Tensor& tensor); } // namespace oneflow diff --git a/oneflow/core/framework/transport_token.cpp b/oneflow/core/framework/transport_token.cpp new file mode 100644 index 00000000000..2defbb242c4 --- /dev/null +++ b/oneflow/core/framework/transport_token.cpp @@ -0,0 +1,295 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ +#include +#include "oneflow/core/framework/transport_token.h" +#include "oneflow/core/common/data_type.h" +#include "oneflow/core/common/data_type.h" +#include "oneflow/core/thread/consistent_unique_id.h" +#include "oneflow/core/framework/rank_group_rpc_util.h" + +namespace oneflow { + +namespace { + +class DataTransportTokenView final { + public: + static Maybe MutCast(TransportToken* transport_token) { + CHECK_EQ_OR_RETURN(transport_token->type(), kDataTransportTokenType); + return reinterpret_cast(transport_token); + } + + void set_data_seq_id(int64_t seq_id) { data_seq_id_ = seq_id; } + + private: + uint16_t src_rank_; + uint16_t dst_rank_; + uint32_t type_ : 2; // TransportTokenType + uint32_t data_seq_id_ : 30; +}; +static_assert(sizeof(DataTransportTokenView) == sizeof(uint64_t), ""); + +class MetaTransportTokenView final { + public: + int64_t thread_consistent_unique_id() const { return thread_consistent_unique_id_; } + int64_t rank_group_level() const { return rank_group_level_; } + + static Maybe MutCast(TransportToken* transport_token) { + CHECK_EQ_OR_RETURN(transport_token->type(), kMetaTransportTokenType); + return reinterpret_cast(transport_token); + } + + static Maybe Cast(const TransportToken* transport_token) { + CHECK_EQ_OR_RETURN(transport_token->type(), kMetaTransportTokenType); + return reinterpret_cast(transport_token); + } + + Maybe set_thread_consistent_unique_id(int8_t val) { + CHECK_GE_OR_RETURN(val, 0); + CHECK_LT_OR_RETURN(val, 1 << kTransportTokenThreadConsistentUIdBit); + thread_consistent_unique_id_ = val; + return Maybe::Ok(); + } + + Maybe set_rank_group_level(int32_t val) { + CHECK_GE_OR_RETURN(val, 0); + CHECK_LT_OR_RETURN(val, 1 << kTransportTokenRankGroupLevelBit); + rank_group_level_ = val; + return Maybe::Ok(); + } + + MetaTransportTokenView& operator++() { + ++low_meta_seq_id_; + if (low_meta_seq_id_ == 0) { ++high_meta_seq_id_; } + return *this; + } + + private: + uint16_t src_rank_; + uint16_t dst_rank_; + uint8_t type_ : 2; // TransportTokenType + uint8_t thread_consistent_unique_id_ : kTransportTokenThreadConsistentUIdBit; + uint8_t rank_group_level_ : kTransportTokenRankGroupLevelBit; + uint8_t high_meta_seq_id_; + uint16_t low_meta_seq_id_; +}; +static_assert(sizeof(MetaTransportTokenView) == sizeof(uint64_t), ""); + +class CtrlTransportTokenView final { + public: + int64_t thread_consistent_unique_id() const { return thread_consistent_unique_id_; } + int64_t rank_group_level() const { return rank_group_level_; } + + static Maybe MutCast(TransportToken* transport_token) { + CHECK_EQ_OR_RETURN(transport_token->type(), kCtrlTransportTokenType); + return reinterpret_cast(transport_token); + } + + static Maybe Cast(const TransportToken* transport_token) { + CHECK_EQ_OR_RETURN(transport_token->type(), kCtrlTransportTokenType); + return reinterpret_cast(transport_token); + } + + Maybe set_thread_consistent_unique_id(int8_t val) { + CHECK_GE_OR_RETURN(val, 0); + CHECK_LT_OR_RETURN(val, 1 << kTransportTokenThreadConsistentUIdBit); + thread_consistent_unique_id_ = val; + return Maybe::Ok(); + } + Maybe set_rank_group_level(int32_t val) { + CHECK_GE_OR_RETURN(val, 0); + CHECK_LT_OR_RETURN(val, 1 << kTransportTokenRankGroupLevelBit); + rank_group_level_ = val; + return Maybe::Ok(); + } + + RankGroupCtrlCmd cmd() const { return static_cast(cmd_); } + + void set_cmd(RankGroupCtrlCmd cmd) { + static_assert(kSizeOfRankGroupCtrlCmd < (1 << 8), ""); + cmd_ = static_cast(cmd); + } + + void set_ctrl_seq_id(int32_t val) { ctrl_seq_id_ = val; } + + private: + 
uint16_t src_rank_; + uint16_t dst_rank_; + uint8_t type_ : 2; // TransportTokenType + uint8_t thread_consistent_unique_id_ : kTransportTokenThreadConsistentUIdBit; + uint8_t rank_group_level_ : kTransportTokenRankGroupLevelBit; + uint8_t cmd_; + uint16_t ctrl_seq_id_; +}; +static_assert(sizeof(CtrlTransportTokenView) == sizeof(uint64_t), ""); + +} // namespace + +TransportToken::TransportToken(TransportTokenType type) { + static_assert(sizeof(TransportToken) == sizeof(int64_t), ""); + *reinterpret_cast(this) = 0; + type_ = type; +} + +/*static*/ TransportToken TransportToken::NewDataTransportToken() { + static auto* seq_id = new std::atomic(); + TransportToken transport_token(kDataTransportTokenType); + CHECK_JUST(DataTransportTokenView::MutCast(&transport_token))->set_data_seq_id(++*seq_id); + return transport_token; +} + +/*static*/ Maybe TransportToken::NewMetaTransportToken() { + int32_t thread_consistent_unique_id = JUST(GetThisThreadConsistentUniqueId()); + int32_t rank_group_level = JUST(GetCurrentRankGroupLevel()); + static const int kLimit = 128; + CHECK_GE_OR_RETURN(rank_group_level, 0); + CHECK_LT_OR_RETURN(rank_group_level, kLimit); + static thread_local std::array, kLimit> transport_token_stack; + auto* current_transport_token = &transport_token_stack[rank_group_level]; + if (!*current_transport_token) { + const auto& init = JUST(NewMetaTransportToken(thread_consistent_unique_id, rank_group_level)); + current_transport_token->reset(new TransportToken(init)); + } + return ++**current_transport_token; +} + +namespace { + +Maybe ThreadLocalMutLock4CtrlTransportToken(int32_t thread_consistent_unique_id, + int32_t rank_group_level, RankGroupCtrlCmd cmd) { + CHECK_EQ_OR_RETURN(thread_consistent_unique_id, JUST(GetThisThreadConsistentUniqueId())); + static const int kTransportTokenRankGroupLevelLimit = (1 << kTransportTokenRankGroupLevelBit); + CHECK_LT_OR_RETURN(rank_group_level, kTransportTokenRankGroupLevelLimit); + static thread_local std::array, + kTransportTokenRankGroupLevelLimit> + transport_token_lock; + return &transport_token_lock[rank_group_level][cmd]; +} + +} // namespace + +/*static*/ Maybe TransportToken::AcquireCtrlTransportToken(RankGroupCtrlCmd cmd) { + int32_t thread_consistent_unique_id = JUST(GetThisThreadConsistentUniqueId()); + int32_t rank_group_level = JUST(GetCurrentRankGroupLevel()); + auto* lock = JUST( + ThreadLocalMutLock4CtrlTransportToken(thread_consistent_unique_id, rank_group_level, cmd)); + CHECK_OR_RETURN(!*lock); + static const int kTransportTokenRankGroupLevelLimit = (1 << kTransportTokenRankGroupLevelBit); + static thread_local std::array< + std::array, kSizeOfRankGroupCtrlCmd>, + kTransportTokenRankGroupLevelLimit> + transport_token_stack; + CHECK_GE_OR_RETURN(rank_group_level, 0); + CHECK_LT_OR_RETURN(rank_group_level, kTransportTokenRankGroupLevelLimit); + CHECK_GE_OR_RETURN(static_cast(cmd), 0); + CHECK_LT_OR_RETURN(static_cast(cmd), kSizeOfRankGroupCtrlCmd); + auto* current_transport_token = &transport_token_stack[rank_group_level][cmd]; + if (!*current_transport_token) { + const auto& init = + JUST(NewCtrlTransportToken(cmd, thread_consistent_unique_id, rank_group_level)); + current_transport_token->reset(new TransportToken(init)); + } + *lock = true; + return **current_transport_token; +} + +Maybe TransportToken::ReleaseCtrlTransportToken() const { + auto* lock = JUST(ThreadLocalMutLock4CtrlTransportToken(JUST(thread_consistent_unique_id()), + JUST(rank_group_level()), JUST(cmd()))); + CHECK_OR_RETURN(*lock); + *lock = false; + return 
Maybe::Ok(); +} + +Maybe TransportToken::thread_consistent_unique_id() const { + if (type() == kMetaTransportTokenType) { + return JUST(MetaTransportTokenView::Cast(this))->thread_consistent_unique_id(); + } else if (type() == kCtrlTransportTokenType) { + return JUST(CtrlTransportTokenView::Cast(this))->thread_consistent_unique_id(); + } else { + UNIMPLEMENTED_THEN_RETURN(); + } + UNIMPLEMENTED_THEN_RETURN(); +} + +Maybe TransportToken::rank_group_level() const { + if (type() == kMetaTransportTokenType) { + return JUST(MetaTransportTokenView::Cast(this))->rank_group_level(); + } else if (type() == kCtrlTransportTokenType) { + return JUST(CtrlTransportTokenView::Cast(this))->rank_group_level(); + } else { + UNIMPLEMENTED_THEN_RETURN(); + } + UNIMPLEMENTED_THEN_RETURN(); +} + +Maybe TransportToken::cmd() const { + return JUST(CtrlTransportTokenView::Cast(this))->cmd(); +} + +Maybe TransportToken::set_src_rank(int64_t src_rank) { + CHECK_GE_OR_RETURN(src_rank, 0); + CHECK_LT_OR_RETURN(src_rank, GetMaxVal()); + src_rank_ = src_rank; + return Maybe::Ok(); +} + +Maybe TransportToken::set_dst_rank(int64_t dst_rank) { + CHECK_GE_OR_RETURN(dst_rank, 0); + CHECK_LT_OR_RETURN(dst_rank, GetMaxVal()); + dst_rank_ = dst_rank; + return Maybe::Ok(); +} + +TransportToken::operator uint64_t() const { + static_assert(sizeof(TransportToken) == sizeof(uint64_t), ""); + return *reinterpret_cast(this); +} + +TransportToken& TransportToken::operator++() { + TransportTokenType transport_token_type = type(); + if (transport_token_type == kDataTransportTokenType) { + UNIMPLEMENTED(); + } else if (transport_token_type == kMetaTransportTokenType) { + ++*CHECK_JUST(MetaTransportTokenView::MutCast(this)); + } else if (transport_token_type == kCtrlTransportTokenType) { + UNIMPLEMENTED(); + } else { + UNIMPLEMENTED(); + } + return *this; +} + +/*static*/ Maybe TransportToken::NewMetaTransportToken( + int32_t thread_consistent_unique_id, int32_t rank_group_level) { + TransportToken transport_token(kMetaTransportTokenType); + auto* view = JUST(MetaTransportTokenView::MutCast(&transport_token)); + JUST(view->set_thread_consistent_unique_id(thread_consistent_unique_id)); + JUST(view->set_rank_group_level(rank_group_level)); + return transport_token; +} + +/*static*/ Maybe TransportToken::NewCtrlTransportToken( + RankGroupCtrlCmd cmd, int32_t thread_consistent_unique_id, int32_t rank_group_level) { + TransportToken transport_token(kCtrlTransportTokenType); + auto* view = JUST(CtrlTransportTokenView::MutCast(&transport_token)); + JUST(view->set_thread_consistent_unique_id(thread_consistent_unique_id)); + JUST(view->set_rank_group_level(rank_group_level)); + view->set_cmd(cmd); + view->set_ctrl_seq_id(0); + return transport_token; +} + +} // namespace oneflow diff --git a/oneflow/core/framework/transport_token.h b/oneflow/core/framework/transport_token.h new file mode 100644 index 00000000000..9980ee358d6 --- /dev/null +++ b/oneflow/core/framework/transport_token.h @@ -0,0 +1,108 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +*/ +#ifndef ONEFLOW_CORE_FRAMEWORK_RPC_TOKEN_H_ +#define ONEFLOW_CORE_FRAMEWORK_RPC_TOKEN_H_ + +#include "oneflow/core/common/type_traits.h" +#include "oneflow/core/common/maybe.h" + +namespace oneflow { + +const static int kTransportTokenTypeBit = 2; +const static int kTransportTokenThreadConsistentUIdBit = 3; +const static int kTransportTokenRankGroupLevelBit = 3; + +enum TransportTokenType { + // Begin + kDataTransportTokenType = 0, // e.g. for tensor data transportation + kMetaTransportTokenType, // e.g. for tensor meta checking + kCtrlTransportTokenType, // e.g. for rank_group or thread checking. see RankGroupCtrlCmd + kExtendedTransportTokenType, // for compatibility + // End + kTransportTokenTypeSize, +}; + +static_assert(kTransportTokenTypeSize <= (1 << kTransportTokenTypeBit), ""); + +enum RankGroupCtrlCmd { + // Begin + kRankGroupCtrlCmdInvalid = 0, + kRankGroupCtrlCmdSyncSymbolParallelDesc, + kRankGroupCtrlCmdSyncSymbolParallelDistribution, + kRankGroupCtrlCmdSyncSymbolConsistentTensorMeta, + kRankGroupCtrlCmdCheckRankGroupConsistency, + kRankGroupCtrlCmdCheckTensorConsistency, + kRankGroupCtrlCmdAll2AllSyncShape, + // End + kSizeOfRankGroupCtrlCmd +}; + +class TransportToken; + +template<> +struct IsScalarType final { + static const bool value = true; +}; + +class TransportToken final { + public: + TransportToken(const TransportToken&) = default; + TransportToken(TransportToken&) = default; + ~TransportToken() = default; + + static TransportToken NewDataTransportToken(); + static Maybe NewMetaTransportToken(); + static Maybe AcquireCtrlTransportToken(RankGroupCtrlCmd cmd); + Maybe ReleaseCtrlTransportToken() const; + + static constexpr size_t MaxNumberOfThreadConsistentUId() { + return (1 << kTransportTokenThreadConsistentUIdBit); + } + + // Getters + int64_t src_rank() const { return src_rank_; } + int64_t dst_rank() const { return dst_rank_; } + TransportTokenType type() const { return static_cast(type_); } + Maybe thread_consistent_unique_id() const; + Maybe rank_group_level() const; + Maybe cmd() const; + + // Setters + Maybe set_src_rank(int64_t src_rank); + Maybe set_dst_rank(int64_t dst_rank); + + operator uint64_t() const; + TransportToken& operator++(); + + private: + explicit TransportToken(TransportTokenType type); + + static Maybe NewMetaTransportToken(int32_t thread_consistent_unique_id, + int32_t rank_group_level); + static Maybe NewCtrlTransportToken(RankGroupCtrlCmd cmd, + int32_t thread_consistent_unique_id, + int32_t rank_group_level); + + uint16_t src_rank_; + uint16_t dst_rank_; + uint32_t type_ : 2; // TransportTokenType + uint32_t opaque_ids_ : 30; +}; +static_assert(sizeof(TransportToken) == sizeof(uint64_t), ""); + +} // namespace oneflow + +#endif // ONEFLOW_CORE_FRAMEWORK_RPC_TOKEN_H_ diff --git a/oneflow/core/framework/rpc_util.cpp b/oneflow/core/framework/transport_util.cpp similarity index 56% rename from oneflow/core/framework/rpc_util.cpp rename to oneflow/core/framework/transport_util.cpp index 5cfda5b0aa9..a402eba5ff6 100644 --- a/oneflow/core/framework/rpc_util.cpp +++ b/oneflow/core/framework/transport_util.cpp @@ -15,8 +15,8 @@ limitations under the License. 
*/ #include #include -#include "oneflow/core/framework/rpc_token.h" -#include "oneflow/core/framework/rpc_util.h" +#include "oneflow/core/framework/transport_token.h" +#include "oneflow/core/framework/transport_util.h" #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/transport/transport.h" #include "oneflow/core/thread/consistent_unique_id.h" @@ -26,7 +26,8 @@ limitations under the License. namespace oneflow { -/*static*/ Maybe RpcUtil::WaitUntilDoneOrTimeout(const AsyncRpcCtx& ctx, int64_t seconds) { +/*static*/ Maybe TransportUtil::WaitUntilDoneOrTimeout(const AsyncTransportCtx& ctx, + int64_t seconds) { const auto& start = std::chrono::steady_clock::now(); const auto& cond_cnt = ctx.flying_cnt(); while (*cond_cnt > 0) { @@ -35,17 +36,20 @@ namespace oneflow { CHECK_LT_OR_RETURN(elapsed_seconds.count(), seconds) << Error::TimeoutError() << "Timeout error at " << seconds << " seconds."; } - if (ctx.rpc_token().type() == kCtrlRpcTokenType) { JUST(ctx.rpc_token().ReleaseCtrlRpcToken()); } + if (ctx.transport_token().type() == kCtrlTransportTokenType) { + JUST(ctx.transport_token().ReleaseCtrlTransportToken()); + } return Maybe::Ok(); } namespace { -template (*SendOrRecv)(const RpcToken&, int64_t, void*, std::size_t, +template (*SendOrRecv)(const TransportToken&, int64_t, void*, std::size_t, const std::function&), - Maybe (AsyncRpcCtx::*Prepare)(int64_t, void**, std::size_t*, std::function*)> -Maybe AccessToAllOtherRanks(Symbol rank_group, const RpcToken& token, - AsyncRpcCtx* ctx) { + Maybe (AsyncTransportCtx::*Prepare)(int64_t, void**, std::size_t*, + std::function*)> +Maybe AccessToAllOtherRanks(Symbol rank_group, const TransportToken& token, + AsyncTransportCtx* ctx) { CHECK_OR_RETURN(rank_group->ContainingCurrentRank()); const auto& flying_cnt = ctx->flying_cnt(); JUST(rank_group->ForEachRank([&](int64_t rank) -> Maybe { @@ -65,11 +69,12 @@ Maybe AccessToAllOtherRanks(Symbol rank_group, const RpcToken& } template (RankGroup::*GetPrevOrNext)() const, - Maybe (*SendOrRecv)(const RpcToken&, int64_t, void*, std::size_t, + Maybe (*SendOrRecv)(const TransportToken&, int64_t, void*, std::size_t, const std::function&), - Maybe (AsyncRpcCtx::*Prepare)(int64_t, void**, std::size_t*, std::function*)> -Maybe AccessToNearbyRank(Symbol rank_group, const RpcToken& token, - AsyncRpcCtx* ctx) { + Maybe (AsyncTransportCtx::*Prepare)(int64_t, void**, std::size_t*, + std::function*)> +Maybe AccessToNearbyRank(Symbol rank_group, const TransportToken& token, + AsyncTransportCtx* ctx) { if (rank_group->size() == 1) { return Maybe::Ok(); } const auto* rank_ranges_ptr = &*rank_group; int64_t rank = JUST((rank_ranges_ptr->*GetPrevOrNext)()); @@ -87,11 +92,11 @@ Maybe AccessToNearbyRank(Symbol rank_group, const RpcToken& tok return Maybe::Ok(); } -Maybe Send(const RpcToken& token, int64_t rank, void* buffer, std::size_t size, +Maybe Send(const TransportToken& token, int64_t rank, void* buffer, std::size_t size, const std::function& Callback) { #ifdef __linux__ auto* transport = JUST(GlobalMaybe()); - RpcToken transport_token(token); + TransportToken transport_token(token); JUST(transport_token.set_src_rank(GlobalProcessCtx::Rank())); JUST(transport_token.set_dst_rank(rank)); transport->Send(static_cast(transport_token), rank, buffer, size, Callback); @@ -102,11 +107,11 @@ Maybe Send(const RpcToken& token, int64_t rank, void* buffer, std::size_t #endif // __linux__ } -Maybe Recv(const RpcToken& token, int64_t rank, void* buffer, std::size_t size, +Maybe Recv(const TransportToken& token, int64_t 
            const std::function& Callback) {
 #ifdef __linux__
   auto* transport = JUST(GlobalMaybe());
-  RpcToken transport_token(token);
+  TransportToken transport_token(token);
   JUST(transport_token.set_src_rank(rank));
   JUST(transport_token.set_dst_rank(GlobalProcessCtx::Rank()));
   transport->Receive(static_cast(transport_token), rank, buffer, size, Callback);
@@ -119,31 +124,37 @@ Maybe Recv(const RpcToken& token, int64_t rank, void* buffer, std::size_t
 }
 
 }  // namespace
 
-/*static*/ Maybe RpcUtil::BroadcastToAllOtherRanks(Symbol rank_group,
-                                                   const RpcToken& token, AsyncRpcCtx* ctx) {
-  JUST(AccessToAllOtherRanks<&Send, &AsyncRpcCtx::PrepareSendBufferAndCallback>(rank_group, token,
-                                                                                ctx));
+/*static*/ Maybe TransportUtil::BroadcastToAllOtherRanks(Symbol rank_group,
+                                                         const TransportToken& token,
+                                                         AsyncTransportCtx* ctx) {
+  JUST(AccessToAllOtherRanks<&Send, &AsyncTransportCtx::PrepareSendBufferAndCallback>(rank_group,
+                                                                                      token, ctx));
   return Maybe::Ok();
 }
 
-/*static*/ Maybe RpcUtil::CollectFromAllOtherRanks(Symbol rank_group,
-                                                   const RpcToken& token, AsyncRpcCtx* ctx) {
-  JUST(AccessToAllOtherRanks<&Recv, &AsyncRpcCtx::PrepareRecvBufferAndCallback>(rank_group, token,
-                                                                                ctx));
+/*static*/ Maybe TransportUtil::CollectFromAllOtherRanks(Symbol rank_group,
+                                                         const TransportToken& token,
+                                                         AsyncTransportCtx* ctx) {
+  JUST(AccessToAllOtherRanks<&Recv, &AsyncTransportCtx::PrepareRecvBufferAndCallback>(rank_group,
+                                                                                      token, ctx));
   return Maybe::Ok();
 }
 
-/*static*/ Maybe RpcUtil::SendToNextRankInRing(Symbol rank_group,
-                                               const RpcToken& token, AsyncRpcCtx* ctx) {
-  JUST(AccessToNearbyRank<&RankGroup::GetNextRankInRing, &Send,
-                          &AsyncRpcCtx::PrepareSendBufferAndCallback>(rank_group, token, ctx));
+/*static*/ Maybe TransportUtil::SendToNextRankInRing(Symbol rank_group,
+                                                     const TransportToken& token,
+                                                     AsyncTransportCtx* ctx) {
+  JUST(
+      AccessToNearbyRank<&RankGroup::GetNextRankInRing, &Send,
+                         &AsyncTransportCtx::PrepareSendBufferAndCallback>(rank_group, token, ctx));
   return Maybe::Ok();
 }
 
-/*static*/ Maybe RpcUtil::ReceiveFromPrevRankInRing(Symbol rank_group,
-                                                    const RpcToken& token, AsyncRpcCtx* ctx) {
-  JUST(AccessToNearbyRank<&RankGroup::GetPrevRankInRing, &Recv,
-                          &AsyncRpcCtx::PrepareRecvBufferAndCallback>(rank_group, token, ctx));
+/*static*/ Maybe TransportUtil::ReceiveFromPrevRankInRing(Symbol rank_group,
+                                                          const TransportToken& token,
+                                                          AsyncTransportCtx* ctx) {
+  JUST(
+      AccessToNearbyRank<&RankGroup::GetPrevRankInRing, &Recv,
+                         &AsyncTransportCtx::PrepareRecvBufferAndCallback>(rank_group, token, ctx));
   return Maybe::Ok();
 }
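For context on the file above: every Send/Recv registered through an AsyncTransportCtx bumps a shared flying counter, the transport callback decrements it on completion, and WaitUntilDoneOrTimeout polls that counter against a deadline. A self-contained sketch of the same scheme, with names invented here rather than taken from OneFlow:

```cpp
// Standalone sketch of the flying-counter wait pattern used by TransportUtil.
#include <atomic>
#include <chrono>
#include <cstdio>
#include <memory>
#include <thread>
#include <vector>

// Returns true when all in-flight operations finished before `seconds` elapsed.
bool WaitUntilDoneOrTimeoutSketch(const std::shared_ptr<std::atomic<int64_t>>& flying_cnt,
                                  int64_t seconds) {
  const auto start = std::chrono::steady_clock::now();
  while (*flying_cnt > 0) {
    auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(
        std::chrono::steady_clock::now() - start);
    if (elapsed.count() >= seconds) { return false; }  // the real code returns a TimeoutError
    std::this_thread::yield();
  }
  return true;
}

int main() {
  auto flying_cnt = std::make_shared<std::atomic<int64_t>>(0);
  std::vector<std::thread> workers;
  // Simulate four async sends/recvs whose completion callbacks decrement the counter.
  for (int i = 0; i < 4; ++i) {
    ++*flying_cnt;
    workers.emplace_back([flying_cnt] {
      std::this_thread::sleep_for(std::chrono::milliseconds(50));
      --*flying_cnt;  // what the transport layer's completion callback does
    });
  }
  std::printf("done in time: %d\n", WaitUntilDoneOrTimeoutSketch(flying_cnt, 5) ? 1 : 0);
  for (auto& t : workers) { t.join(); }
  return 0;
}
```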
diff --git a/oneflow/core/framework/rpc_util.h b/oneflow/core/framework/transport_util.h
similarity index 67%
rename from oneflow/core/framework/rpc_util.h
rename to oneflow/core/framework/transport_util.h
index a30dace0c6d..20cd4e71cf1
--- a/oneflow/core/framework/rpc_util.h
+++ b/oneflow/core/framework/transport_util.h
@@ -19,17 +19,17 @@ limitations under the License.
 #include
 #include "oneflow/core/common/maybe.h"
 #include "oneflow/core/common/symbol.h"
-#include "oneflow/core/framework/rpc_token.h"
+#include "oneflow/core/framework/transport_token.h"
 
 namespace oneflow {
 
-class AsyncRpcCtx {
+class AsyncTransportCtx {
  public:
-  explicit AsyncRpcCtx(const RpcToken& rpc_token)
-      : rpc_token_(rpc_token), flying_cnt_(new std::atomic(0)) {}
-  virtual ~AsyncRpcCtx() = default;
+  explicit AsyncTransportCtx(const TransportToken& transport_token)
+      : transport_token_(transport_token), flying_cnt_(new std::atomic(0)) {}
+  virtual ~AsyncTransportCtx() = default;
 
-  const RpcToken& rpc_token() const { return rpc_token_; }
+  const TransportToken& transport_token() const { return transport_token_; }
   std::shared_ptr> flying_cnt() const { return flying_cnt_; }
 
   virtual Maybe PrepareSendBufferAndCallback(int64_t rank, void** buffer, std::size_t* size,
@@ -39,46 +39,49 @@ class AsyncRpcCtx {
                                              std::function* Callback) = 0;
 
  private:
-  RpcToken rpc_token_;
+  TransportToken transport_token_;
   std::shared_ptr> flying_cnt_;
 };
 
-class NaiveAsyncRpcCtx final : public AsyncRpcCtx {
+class NaiveAsyncTransportCtx final : public AsyncTransportCtx {
  public:
-  NaiveAsyncRpcCtx(
-      const RpcToken& rpc_token,
+  NaiveAsyncTransportCtx(
+      const TransportToken& transport_token,
       const std::function(void**, std::size_t*, std::function*)>& PrepareSend,
       const std::function(void**, std::size_t*, std::function*)>& PrepareRecv)
-      : AsyncRpcCtx(rpc_token), prepare_send_(PrepareSend), prepare_recv_(PrepareRecv) {}
+      : AsyncTransportCtx(transport_token),
+        prepare_send_(PrepareSend),
+        prepare_recv_(PrepareRecv) {}
 
-  NaiveAsyncRpcCtx(
-      const RpcToken& rpc_token,
+  NaiveAsyncTransportCtx(
+      const TransportToken& transport_token,
       const std::function(void**, std::size_t*, std::function*)>& PrepareSend,
       const std::function(int64_t, void**, std::size_t*,
                           std::function*)>& PrepareRecvWithRank)
-      : AsyncRpcCtx(rpc_token),
+      : AsyncTransportCtx(transport_token),
        prepare_send_(PrepareSend),
        prepare_recv_with_rank_(PrepareRecvWithRank) {}
 
-  NaiveAsyncRpcCtx(
-      const RpcToken& rpc_token,
+  NaiveAsyncTransportCtx(
+      const TransportToken& transport_token,
       const std::function(int64_t, void**, std::size_t*,
                           std::function*)>& PrepareSendWithRank,
       const std::function(void**, std::size_t*, std::function*)>& PrepareRecv)
-      : AsyncRpcCtx(rpc_token),
+      : AsyncTransportCtx(transport_token),
        prepare_send_with_rank_(PrepareSendWithRank),
        prepare_recv_(PrepareRecv) {}
 
-  NaiveAsyncRpcCtx(const RpcToken& rpc_token,
-                   const std::function(int64_t, void**, std::size_t*,
-                                       std::function*)>& PrepareSendWithRank,
-                   const std::function(int64_t, void**, std::size_t*,
-                                       std::function*)>& PrepareRecvWithRank)
-      : AsyncRpcCtx(rpc_token),
+  NaiveAsyncTransportCtx(
+      const TransportToken& transport_token,
+      const std::function(int64_t, void**, std::size_t*, std::function*)>&
+          PrepareSendWithRank,
+      const std::function(int64_t, void**, std::size_t*, std::function*)>&
+          PrepareRecvWithRank)
+      : AsyncTransportCtx(transport_token),
        prepare_send_with_rank_(PrepareSendWithRank),
        prepare_recv_with_rank_(PrepareRecvWithRank) {}
 
-  ~NaiveAsyncRpcCtx() override = default;
+  ~NaiveAsyncTransportCtx() override = default;
 
   Maybe PrepareSendBufferAndCallback(int64_t rank, void** buffer, std::size_t* size,
                                      std::function* Callback) override {
@@ -103,22 +106,22 @@ class NaiveAsyncRpcCtx final : public AsyncRpcCtx {
 
 class RankGroup;
 
-struct RpcUtil final {
+struct TransportUtil final {
   static int64_t TimeoutSeconds() { return 60 * 5; }
 
-  static Maybe WaitUntilDoneOrTimeout(const AsyncRpcCtx& ctx, int64_t seconds);
+  static Maybe WaitUntilDoneOrTimeout(const AsyncTransportCtx& ctx, int64_t seconds);
 
-  static Maybe SendToNextRankInRing(Symbol rank_group, const RpcToken& token,
-                                    AsyncRpcCtx* ctx);
+  static Maybe SendToNextRankInRing(Symbol rank_group, const TransportToken& token,
+                                    AsyncTransportCtx* ctx);
 
-  static Maybe ReceiveFromPrevRankInRing(Symbol rank_group, const RpcToken& token,
-                                         AsyncRpcCtx* ctx);
+  static Maybe ReceiveFromPrevRankInRing(Symbol rank_group,
+                                         const TransportToken& token, AsyncTransportCtx* ctx);
 
-  static Maybe BroadcastToAllOtherRanks(Symbol rank_group, const RpcToken& token,
-                                        AsyncRpcCtx* ctx);
+  static Maybe BroadcastToAllOtherRanks(Symbol rank_group,
+                                        const TransportToken& token, AsyncTransportCtx* ctx);
 
-  static Maybe CollectFromAllOtherRanks(Symbol rank_group, const RpcToken& token,
-                                        AsyncRpcCtx* ctx);
+  static Maybe CollectFromAllOtherRanks(Symbol rank_group,
+                                        const TransportToken& token, AsyncTransportCtx* ctx);
 };
 
 }  // namespace oneflow
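A hedged usage sketch of the renamed header: pushing a small trivially copyable payload around the ring with NaiveAsyncTransportCtx and TransportUtil. RingPayload and RingExchangeSketch are hypothetical names invented here, not added by this patch, and the sketch has not been compiled; the call pattern mirrors the All2AllSyncShape change further down.

```cpp
// Sketch of a ring exchange built on the renamed transport API (assumes the
// OneFlow source tree; RingPayload/RingExchangeSketch are illustrative only).
#include <cstddef>
#include <functional>
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/common/symbol.h"
#include "oneflow/core/framework/transport_token.h"
#include "oneflow/core/framework/transport_util.h"
#include "oneflow/core/job/rank_group_scope.h"

namespace oneflow {

struct RingPayload {
  int64_t value;  // must stay trivially copyable: it is shipped as a raw buffer
};

Maybe<void> RingExchangeSketch(RingPayload* send, RingPayload* recv) {
  const auto& rank_group = JUST(RankGroupScope::CurrentRankGroup());
  // Data tokens are allocated without the ctrl-command handshake.
  TransportToken token = TransportToken::NewDataTransportToken();
  NaiveAsyncTransportCtx ctx(
      token,
      [send](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {
        *buffer = send;
        *size = sizeof(RingPayload);
        *Cb = [] {};
        return Maybe<void>::Ok();
      },
      [recv](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {
        *buffer = recv;
        *size = sizeof(RingPayload);
        *Cb = [] {};
        return Maybe<void>::Ok();
      });
  JUST(TransportUtil::SendToNextRankInRing(rank_group, token, &ctx));
  JUST(TransportUtil::ReceiveFromPrevRankInRing(rank_group, token, &ctx));
  JUST(TransportUtil::WaitUntilDoneOrTimeout(ctx, TransportUtil::TimeoutSeconds()));
  return Maybe<void>::Ok();
}

}  // namespace oneflow
```

Ctrl tokens are released inside WaitUntilDoneOrTimeout (see the kCtrlTransportTokenType branch above); data tokens need no release, which is why the sketch uses NewDataTransportToken.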
diff --git a/oneflow/core/functional/impl/consistent_cast.cpp b/oneflow/core/functional/impl/consistent_cast.cpp
index 09d05975f5d..3c4d0ad6461 100644
--- a/oneflow/core/functional/impl/consistent_cast.cpp
+++ b/oneflow/core/functional/impl/consistent_cast.cpp
@@ -30,8 +30,8 @@ limitations under the License.
 #include "oneflow/core/job/global_for.h"
 #include "oneflow/core/job/resource_desc.h"
 #include "oneflow/core/job/rank_group_scope.h"
-#include "oneflow/core/framework/rpc_token.h"
-#include "oneflow/core/framework/rpc_util.h"
+#include "oneflow/core/framework/transport_token.h"
+#include "oneflow/core/framework/transport_util.h"
 #include "oneflow/core/common/flat_shape.h"
 #include "oneflow/core/common/container_util.h"
 #include "oneflow/core/common/balanced_splitter.h"
@@ -45,12 +45,13 @@ namespace impl {
 namespace {
 
 Maybe>> All2AllSyncShape(const Shape& shape) {
-  const auto& rpc_token = JUST(RpcToken::AcquireCtrlRpcToken(kRankGroupRpcCmdAll2AllSyncShape));
+  const auto& transport_token =
+      JUST(TransportToken::AcquireCtrlTransportToken(kRankGroupCtrlCmdAll2AllSyncShape));
   const auto& send_buffer = JUST(FlatShape::New(shape));
   const auto& map = std::make_shared>>();
   map->emplace(GlobalProcessCtx::Rank(), send_buffer);
-  NaiveAsyncRpcCtx ctx(
-      rpc_token,
+  NaiveAsyncTransportCtx ctx(
+      transport_token,
       [send_buffer](void** buffer, std::size_t* size, std::function* Cb) -> Maybe {
         *buffer = send_buffer.get();
         *size = sizeof(FlatShape);
@@ -68,9 +69,9 @@ Maybe>> All2AllSyncShape(const Shape
         return Maybe::Ok();
       });
   const auto& rank_group = JUST(RankGroupScope::CurrentRankGroup());
-  JUST(RpcUtil::BroadcastToAllOtherRanks(rank_group, rpc_token, &ctx));
-  JUST(RpcUtil::CollectFromAllOtherRanks(rank_group, rpc_token, &ctx));
-  JUST(RpcUtil::WaitUntilDoneOrTimeout(ctx, RpcUtil::TimeoutSeconds()));
+  JUST(TransportUtil::BroadcastToAllOtherRanks(rank_group, transport_token, &ctx));
+  JUST(TransportUtil::CollectFromAllOtherRanks(rank_group, transport_token, &ctx));
+  JUST(TransportUtil::WaitUntilDoneOrTimeout(ctx, TransportUtil::TimeoutSeconds()));
   return map;
 }
 
diff --git a/oneflow/core/thread/consistent_unique_id.cpp b/oneflow/core/thread/consistent_unique_id.cpp
index 6fb5ce3364e..377177071ec 100644
--- a/oneflow/core/thread/consistent_unique_id.cpp
+++ b/oneflow/core/thread/consistent_unique_id.cpp
@@ -15,7 +15,7 @@ limitations under the License.
  */
 #include "oneflow/core/thread/consistent_unique_id.h"
 #include "oneflow/core/common/util.h"
-#include "oneflow/core/framework/rpc_util.h"
+#include "oneflow/core/framework/transport_util.h"
 #include "oneflow/core/common/container_util.h"
 
 namespace oneflow {
@@ -34,7 +34,7 @@ class ConsistentUniqueIdStorage final {
 
   Maybe Emplace(int64_t id, const std::string& debug_string) {
     std::unique_lock lock(mutex_);
-    CHECK_LE_OR_RETURN(id2debug_string_.size(), RpcToken::MaxNumberOfThreadConsistentUId());
+    CHECK_LE_OR_RETURN(id2debug_string_.size(), TransportToken::MaxNumberOfThreadConsistentUId());
     for (const auto& pair : id2debug_string_) { CHECK_NE_OR_RETURN(debug_string, pair.second); }
     CHECK_OR_RETURN(id2debug_string_.emplace(id, debug_string).second);
     return Maybe::Ok();
diff --git a/python/oneflow/__init__.py b/python/oneflow/__init__.py
index ea514371428..790eafe921b 100644
--- a/python/oneflow/__init__.py
+++ b/python/oneflow/__init__.py
@@ -80,7 +80,7 @@ def is_deprecated(func_or_class):
 env_util.SetDefaultMultiClientEnvVars()
 oneflow._oneflow_internal.SetIsMultiClient(True)
 env_util.api_env_init()
-oneflow._oneflow_internal.InitDefaultConsistentRpcTokenScope()
+oneflow._oneflow_internal.InitDefaultConsistentTransportTokenScope()
 session_ctx.OpenDefaultSession(
     MultiClientSession(oneflow._oneflow_internal.NewSessionId())
 )