From 0a44b54ebaf5d439d1239333e045d758a3911b6e Mon Sep 17 00:00:00 2001
From: Li Xinqi
Date: Sat, 31 Jul 2021 08:10:32 +0800
Subject: [PATCH] New sync consistent meta info (#5634)

* rebase
* check in gen py
* merge master and fix bugs
* address pr comments
* address pr comments
* auto format by CI
* rebase
* address pr comments
* auto format by CI
* functional python_arg
* reuse ctrl rpc token for avoiding long time timeout waiting.
* fix compiler complaints
* auto format by CI
* auto format by CI
* remove unused files
* fix return type error on gcc 4.8.5

Signed-off-by: daquexian

* auto format by CI
* fix return type error in xrt

Signed-off-by: daquexian

* fix tick ibn sbp signature
* auto format by CI

Co-authored-by: tsai
Co-authored-by: oneflow-ci-bot
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Co-authored-by: daquexian
---
 oneflow/api/python/framework/tensor.cpp | 20 ++
 .../python/rpc/consistent_rpc_token_scope.cpp | 60 ++++
 oneflow/api/python/rpc/rank_group.cpp | 47 +++
 .../api/python/symbol/placement_symbol.cpp | 1 +
 oneflow/core/common/error.cpp | 6 +
 oneflow/core/common/error.h | 1 +
 oneflow/core/common/error.proto | 3 +
 oneflow/core/common/exception.h | 1 +
 oneflow/core/common/flat_shape.cpp | 36 +++
 oneflow/core/common/flat_shape.h | 44 +++
 oneflow/core/common/maybe.h | 6 +
 oneflow/core/control/ctrl_bootstrap.proto | 4 -
 .../init_symbol_instruction_type_test.cpp | 12 -
 .../eager/lazy_job_instruction_type_test.cpp | 15 -
 .../eager/opkernel_instruction_type_test.cpp | 23 --
 .../transport_blob_instruction_type_test.cpp | 9 -
 .../consistent_tensor_infer_cache.cpp | 34 ++-
 .../framework/consistent_tensor_infer_cache.h | 13 +-
 oneflow/core/framework/interpreter_test.cpp | 12 -
 oneflow/core/framework/op_expr.cpp | 4 +-
 .../eager_consistent_op_interpreter.cpp | 3 +
 .../core/framework/rank_group_rpc_util.cpp | 54 ++++
 oneflow/core/framework/rank_group_rpc_util.h | 32 ++
 oneflow/core/framework/rpc_token.cpp | 285 ++++++++++++++++++
 oneflow/core/framework/rpc_token.h | 106 +++++++
 oneflow/core/framework/rpc_util.cpp | 134 ++++++++
 oneflow/core/framework/rpc_util.h | 82 +++++
 .../sync_symbol_consistent_tensor_meta.cpp | 100 ++++++
 .../sync_symbol_consistent_tensor_meta.h | 34 +++
 .../framework/sync_symbol_parallel_desc.cpp | 99 ++++++
 .../framework/sync_symbol_parallel_desc.h | 32 ++
 .../sync_symbol_parallel_distribution.cpp | 173 +++++++++++
 .../sync_symbol_parallel_distribution.h | 36 +++
 oneflow/core/framework/synced_symbol_map.cpp | 25 ++
 oneflow/core/framework/synced_symbol_map.h | 74 +++++
 oneflow/core/framework/tensor.cpp | 1 +
 oneflow/core/framework/tensor.h | 17 +-
 oneflow/core/framework/tensor_impl.h | 19 +-
 oneflow/core/framework/tensor_rpc_util.cpp | 149 +++++++++
 oneflow/core/framework/tensor_rpc_util.h | 56 ++++
 oneflow/core/job/id_manager_test.cpp | 7 +-
 oneflow/core/job/parallel_desc.cpp | 1 +
 oneflow/core/job/parallel_desc.h | 3 +
 oneflow/core/job/parallel_desc_test.cpp | 36 +--
 oneflow/core/job/rank_group.cpp | 94 ++++++
 oneflow/core/job/rank_group.h | 71 +++++
 oneflow/core/job/rank_group_scope.cpp | 77 +++++
 oneflow/core/job/rank_group_scope.h | 52 ++++
 oneflow/core/job/rank_group_scope_test.cpp | 64 ++++
 oneflow/core/job/rank_group_test.cpp | 62 ++++
 oneflow/core/rpc/lib/global_process_ctx.cpp | 3 -
 oneflow/core/thread/consistent_unique_id.cpp | 78 +++++
 oneflow/core/thread/consistent_unique_id.h | 31 ++
 oneflow/core/vm/nop_stream_type_test.cpp | 24 +-
 .../core/vm/object_instruction_type_test.cpp | 22 +-
 .../vm/sequential_instruction_type_test.cpp | 8 +-
 oneflow/core/vm/test_util.cpp | 6 +-
 python/oneflow/__init__.py | 1 +
 58 files changed, 2345 insertions(+), 157 deletions(-)
 create mode 100644 oneflow/api/python/rpc/consistent_rpc_token_scope.cpp
 create mode 100644 oneflow/api/python/rpc/rank_group.cpp
 create mode 100644 oneflow/core/common/flat_shape.cpp
 create mode 100644 oneflow/core/common/flat_shape.h
 create mode 100644 oneflow/core/framework/rank_group_rpc_util.cpp
 create mode 100644 oneflow/core/framework/rank_group_rpc_util.h
 create mode 100644 oneflow/core/framework/rpc_token.cpp
 create mode 100644 oneflow/core/framework/rpc_token.h
 create mode 100644 oneflow/core/framework/rpc_util.cpp
 create mode 100644 oneflow/core/framework/rpc_util.h
 create mode 100644 oneflow/core/framework/sync_symbol_consistent_tensor_meta.cpp
 create mode 100644 oneflow/core/framework/sync_symbol_consistent_tensor_meta.h
 create mode 100644 oneflow/core/framework/sync_symbol_parallel_desc.cpp
 create mode 100644 oneflow/core/framework/sync_symbol_parallel_desc.h
 create mode 100644 oneflow/core/framework/sync_symbol_parallel_distribution.cpp
 create mode 100644 oneflow/core/framework/sync_symbol_parallel_distribution.h
 create mode 100644 oneflow/core/framework/synced_symbol_map.cpp
 create mode 100644 oneflow/core/framework/synced_symbol_map.h
 create mode 100644 oneflow/core/framework/tensor_rpc_util.cpp
 create mode 100644 oneflow/core/framework/tensor_rpc_util.h
 create mode 100644 oneflow/core/job/rank_group.cpp
 create mode 100644 oneflow/core/job/rank_group.h
 create mode 100644 oneflow/core/job/rank_group_scope.cpp
 create mode 100644 oneflow/core/job/rank_group_scope.h
 create mode 100644 oneflow/core/job/rank_group_scope_test.cpp
 create mode 100644 oneflow/core/job/rank_group_test.cpp
 create mode 100644 oneflow/core/thread/consistent_unique_id.cpp
 create mode 100644 oneflow/core/thread/consistent_unique_id.h

diff --git a/oneflow/api/python/framework/tensor.cpp b/oneflow/api/python/framework/tensor.cpp
index f0b2744af39..4cf1129259d 100644
--- a/oneflow/api/python/framework/tensor.cpp
+++ b/oneflow/api/python/framework/tensor.cpp
@@ -26,6 +26,7 @@ limitations under the License.
#include "oneflow/core/common/tensor_buffer.h" #include "oneflow/core/framework/instructions_builder.h" #include "oneflow/core/framework/tensor.h" +#include "oneflow/core/framework/tensor_rpc_util.h" #include "oneflow/core/framework/tensor_method.h" #include "oneflow/core/framework/device.h" #include "oneflow/core/framework/stride.h" @@ -247,6 +248,13 @@ void ApiRegisterTensorHook(const std::shared_ptr& self, const AutogradMe return RegisterTensorHook(self, hook).GetOrThrow(); } +Maybe CheckConsistentTensorMeta(const one::Tensor& tensor, int64_t seconds) { + const auto& ctx = JUST(LaunchTensorMetaConsistencyCheck(tensor)); + JUST(RpcUtil::WaitUntilDoneOrTimeout(*ctx, seconds)); + JUST(ctx->Check()); + return Maybe::Ok(); +} + bool ApiIsContiguous(const std::shared_ptr& tensor) { return IsContiguous(tensor).GetOrThrow(); } @@ -403,6 +411,18 @@ ONEFLOW_API_PYBIND11_MODULE("", m) { .def_property_readonly("_tensor_buffer_shapes_and_dtypes", &GetTensorBufferShapesAndDTypes) .def_property_readonly("device", &TensorGetDevice) .def_property_readonly("data", &Tensor::data) + .def("rpc_token", + [](const one::Tensor& tensor) -> int64_t { + return static_cast(tensor.rpc_token().GetOrThrow()); + }) + .def("check_meta_consistency", + [](const one::Tensor& tensor) { + return CheckConsistentTensorMeta(tensor, 60 * 5).GetOrThrow(); + }) + .def("check_meta_consistency", + [](const one::Tensor& tensor, int64_t seconds) { + return CheckConsistentTensorMeta(tensor, seconds).GetOrThrow(); + }) #define DEFINE_TENSOR_METHOD(T, type_proto) \ .def("_copy_to_numpy_" #T, &ApiCopyMirroredTensorToNumpy) \ .def("_copy_from_numpy_" #T, &ApiCopyMirroredTensorFromNumpy) diff --git a/oneflow/api/python/rpc/consistent_rpc_token_scope.cpp b/oneflow/api/python/rpc/consistent_rpc_token_scope.cpp new file mode 100644 index 00000000000..72ea1da4020 --- /dev/null +++ b/oneflow/api/python/rpc/consistent_rpc_token_scope.cpp @@ -0,0 +1,60 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include +#include +#include +#include "oneflow/api/python/of_api_registry.h" +#include "oneflow/core/thread/consistent_unique_id.h" +#include "oneflow/core/framework/rank_group_rpc_util.h" +#include "oneflow/core/job/rank_group.h" +#include "oneflow/core/job/rank_group_scope.h" +#include "oneflow/core/common/symbol.h" + +namespace py = pybind11; + +namespace oneflow { + +namespace { + +Maybe InitConsistentRpcTokenScope(const std::string& thread_tag, + int64_t thread_consistent_uid, + Symbol rank_group) { + JUST(SetThisThreadConsistentUniqueId(thread_consistent_uid, thread_tag)); + static thread_local const auto& init_rank_group_scope = + JUST(RankGroupScope::MakeInitialRankGroupScope(rank_group)); + // no unused warning for `init_rank_group_scope`. 
+ (void)(init_rank_group_scope); + return Maybe::Ok(); +} + +Maybe InitConsistentRpcTokenScope(const std::string& thread_tag, + int64_t thread_consistent_uid) { + const auto& rank_group = JUST(RankGroup::DefaultRankGroup()); + JUST(InitConsistentRpcTokenScope(thread_tag, thread_consistent_uid, rank_group)); + return Maybe::Ok(); +} + +void ApiInitDefaultConsistentRpcTokenScope() { + return InitConsistentRpcTokenScope("main", 0).GetOrThrow(); +} + +} // namespace + +ONEFLOW_API_PYBIND11_MODULE("", m) { + m.def("InitDefaultConsistentRpcTokenScope", &ApiInitDefaultConsistentRpcTokenScope); +} + +} // namespace oneflow diff --git a/oneflow/api/python/rpc/rank_group.cpp b/oneflow/api/python/rpc/rank_group.cpp new file mode 100644 index 00000000000..e7414e45c64 --- /dev/null +++ b/oneflow/api/python/rpc/rank_group.cpp @@ -0,0 +1,47 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include +#include +#include +#include "oneflow/api/python/of_api_registry.h" +#include "oneflow/core/framework/rank_group_rpc_util.h" +#include "oneflow/core/job/rank_group.h" +#include "oneflow/core/job/rank_group_scope.h" +#include "oneflow/core/common/symbol.h" + +namespace py = pybind11; + +namespace oneflow { + +namespace { + +Maybe CheckCurrentRankGroupConsistency(int64_t seconds) { + const auto& rank_group = JUST(RankGroupScope::CurrentRankGroup()); + const auto& ctx = JUST(CheckRpcToken(rank_group)); + JUST(RpcUtil::WaitUntilDoneOrTimeout(*ctx, seconds)); + return Maybe::Ok(); +} + +} // namespace + +ONEFLOW_API_PYBIND11_MODULE("", m) { + m.def("check_current_rank_group_consistency", + [](int64_t seconds) { return CheckCurrentRankGroupConsistency(seconds).GetOrThrow(); }); + m.def("check_current_rank_group_consistency", + []() { return CheckCurrentRankGroupConsistency(60 * 5).GetOrThrow(); }); +} + +} // namespace oneflow diff --git a/oneflow/api/python/symbol/placement_symbol.cpp b/oneflow/api/python/symbol/placement_symbol.cpp index bf255f49470..233084f68c7 100644 --- a/oneflow/api/python/symbol/placement_symbol.cpp +++ b/oneflow/api/python/symbol/placement_symbol.cpp @@ -193,6 +193,7 @@ struct PlacementSymbolExportUtil { return placement_str; } }; + } // namespace ONEFLOW_API_PYBIND11_MODULE("", m) { diff --git a/oneflow/core/common/error.cpp b/oneflow/core/common/error.cpp index b589b33552d..60e38e9ceb0 100644 --- a/oneflow/core/common/error.cpp +++ b/oneflow/core/common/error.cpp @@ -76,6 +76,12 @@ Error Error::IndexError() { return error; } +Error Error::TimeoutError() { + auto error = std::make_shared(); + error->mutable_timeout_error(); + return error; +} + Error Error::JobNameExistError() { auto error = std::make_shared(); error->mutable_job_name_exist_error(); diff --git a/oneflow/core/common/error.h b/oneflow/core/common/error.h index 5a878e3f9c6..aedfe8133d4 100644 --- a/oneflow/core/common/error.h +++ b/oneflow/core/common/error.h @@ -44,6 +44,7 @@ class Error final { static Error DeviceTagNotFoundError(); static Error ValueError(const std::string& 
error_summary); static Error IndexError(); + static Error TimeoutError(); static Error JobNameExistError(); static Error JobNameEmptyError(); static Error JobNameNotEqualError(); diff --git a/oneflow/core/common/error.proto b/oneflow/core/common/error.proto index 81dd7b62257..66217665ff9 100644 --- a/oneflow/core/common/error.proto +++ b/oneflow/core/common/error.proto @@ -127,6 +127,8 @@ message ValueError {} message IndexError {} +message TimeoutError {} + message ErrorProto { optional string error_summary = 1 [default = ""]; optional string msg = 2 [default = ""]; @@ -148,6 +150,7 @@ message ErrorProto { DeviceTagNotFoundError device_tag_not_found_error = 26; ValueError value_error = 27; IndexError index_error = 28; + TimeoutError timeout_error = 29; JobNameExistError job_name_exist_error = 100; JobNameEmptyError job_name_empty_error = 101; JobNameNotEqualError job_name_not_equal_error = 102; diff --git a/oneflow/core/common/exception.h b/oneflow/core/common/exception.h index e1e1dee1fb7..2fbd433ae3b 100644 --- a/oneflow/core/common/exception.h +++ b/oneflow/core/common/exception.h @@ -70,6 +70,7 @@ class Exception : public std::exception { OF_PP_MAKE_TUPLE_SEQ(CompileOptionWrong) \ OF_PP_MAKE_TUPLE_SEQ(Value) \ OF_PP_MAKE_TUPLE_SEQ(Index) \ + OF_PP_MAKE_TUPLE_SEQ(Timeout) \ OF_PP_MAKE_TUPLE_SEQ(InputDeviceNotMatch) #define DEFINE_EXCEPTION_CLASS(cls) \ diff --git a/oneflow/core/common/flat_shape.cpp b/oneflow/core/common/flat_shape.cpp new file mode 100644 index 00000000000..78374baec00 --- /dev/null +++ b/oneflow/core/common/flat_shape.cpp @@ -0,0 +1,36 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/core/common/flat_shape.h" +#include "oneflow/core/common/shape.h" + +namespace oneflow { + +Maybe FlatShape::Init(const std::shared_ptr& shape) { + CHECK_LE_OR_RETURN(shape->NumAxes(), SHAPE_MAX_AXIS_SIZE); + this->set_num_axes(shape->NumAxes()); + for (int i = 0; i < this->num_axes(); ++i) { *this->mutable_dim()->Mutable(i) = shape->At(i); } + return Maybe::Ok(); +} + +Maybe FlatShape::Check(const std::shared_ptr& shape) const { + CHECK_EQ_OR_RETURN(this->num_axes(), shape->NumAxes()); + for (int i = 0; i < this->num_axes(); ++i) { + CHECK_EQ_OR_RETURN(this->dim().Get(i), shape->At(i)); + } + return Maybe::Ok(); +} + +} // namespace oneflow diff --git a/oneflow/core/common/flat_shape.h b/oneflow/core/common/flat_shape.h new file mode 100644 index 00000000000..8b70e31fa6b --- /dev/null +++ b/oneflow/core/common/flat_shape.h @@ -0,0 +1,44 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +*/ +#ifndef ONEFLOW_CORE_COMMON_FLAT_SHAPE_H_ +#define ONEFLOW_CORE_COMMON_FLAT_SHAPE_H_ + +#include +#include "oneflow/core/object_msg/flat_msg.h" +#include "oneflow/core/common/maybe.h" +#include "oneflow/core/common/shape_vec.h" + +namespace oneflow { + +class Shape; + +// clang-format off + +FLAT_MSG_BEGIN(FlatShape); + // Methods + OF_PUBLIC Maybe Init(const std::shared_ptr& shape); + OF_PUBLIC Maybe Check(const std::shared_ptr& shape) const; + + // Fields + FLAT_MSG_DEFINE_OPTIONAL(int64_t, num_axes); + FLAT_MSG_DEFINE_REPEATED(int64_t, dim, SHAPE_MAX_AXIS_SIZE); +FLAT_MSG_END(FlatShape); + +// clang-format on + +} // namespace oneflow + +#endif // ONEFLOW_CORE_COMMON_FLAT_SHAPE_H_ diff --git a/oneflow/core/common/maybe.h b/oneflow/core/common/maybe.h index 00435c531ec..33a028394a2 100644 --- a/oneflow/core/common/maybe.h +++ b/oneflow/core/common/maybe.h @@ -149,12 +149,18 @@ class Maybe::value>::type> fina SharedOrScalar error_or_scalar_; }; +inline const std::shared_ptr& UninitializedValueError() { + static thread_local const auto& error = Error::ValueError("uninitialized value").error_proto(); + return error; +} + template class Maybe::value>::type> final { public: Maybe(T data) : error_or_scalar_(data) {} Maybe(const Error& error) : error_or_scalar_(error.error_proto()) { CheckError(); } Maybe(const std::shared_ptr& error) : error_or_scalar_(error) { CheckError(); } + Maybe() : error_or_scalar_(UninitializedValueError()) {} Maybe(const Maybe&) = default; Maybe(Maybe&&) = default; ~Maybe() = default; diff --git a/oneflow/core/control/ctrl_bootstrap.proto b/oneflow/core/control/ctrl_bootstrap.proto index 6918bb69ba6..ebf6cbf622e 100644 --- a/oneflow/core/control/ctrl_bootstrap.proto +++ b/oneflow/core/control/ctrl_bootstrap.proto @@ -20,7 +20,3 @@ message BootstrapConf { optional int32 ctrl_port = 5 [default = -1]; optional int64 node_size = 6 [default = -1]; } - -message NumProcessPerNode { - required int64 value = 1; -} diff --git a/oneflow/core/eager/init_symbol_instruction_type_test.cpp b/oneflow/core/eager/init_symbol_instruction_type_test.cpp index 4a3dc93c191..8f2e80e67fe 100644 --- a/oneflow/core/eager/init_symbol_instruction_type_test.cpp +++ b/oneflow/core/eager/init_symbol_instruction_type_test.cpp @@ -39,14 +39,6 @@ namespace oneflow { namespace vm { namespace test { -namespace { - -void InitNumProcessPerNode() { Global::New()->set_value(1); } - -void DestroyNumProcessPerNode() { Global::Delete(); } - -} // namespace - using InstructionMsgList = OBJECT_MSG_LIST(vm::InstructionMsg, instr_msg_link); template @@ -72,25 +64,21 @@ void TestInitSymbolInstructionType(const std::string& instr_type_name) { } TEST(InitSymbolInstructionType, job_desc) { - InitNumProcessPerNode(); #ifdef WITH_CUDA vm::TestResourceDescScope resource_scope(1, 1); #else vm::TestResourceDescScope resource_scope(0, 1); #endif TestInitSymbolInstructionType("InitJobDescSymbol"); - DestroyNumProcessPerNode(); } TEST(InitSymbolInstructionType, operator_conf) { - InitNumProcessPerNode(); #ifdef WITH_CUDA vm::TestResourceDescScope resource_scope(1, 1); #else vm::TestResourceDescScope resource_scope(0, 1); #endif TestInitSymbolInstructionType("InitOperatorConfSymbol"); - DestroyNumProcessPerNode(); } } // namespace test diff --git a/oneflow/core/eager/lazy_job_instruction_type_test.cpp b/oneflow/core/eager/lazy_job_instruction_type_test.cpp index f8fa7dcf319..d2245ec8369 100644 --- 
a/oneflow/core/eager/lazy_job_instruction_type_test.cpp +++ b/oneflow/core/eager/lazy_job_instruction_type_test.cpp @@ -40,17 +40,6 @@ namespace oneflow { namespace vm { namespace test { -namespace { - -void InitNumProcessPerNode() { - Global::New(); - Global::Get()->set_value(1); -} - -void DestroyNumProcessPerNode() { Global::Delete(); } - -} // namespace - using InstructionMsgList = OBJECT_MSG_LIST(vm::InstructionMsg, instr_msg_link); class NoArgNoRetMockNNGraph : public NNGraphIf { @@ -73,7 +62,6 @@ class NoArgNoRetMockNNGraph : public NNGraphIf { }; TEST(RunLazyJobInstructionType, simple) { - InitNumProcessPerNode(); vm::TestResourceDescScope resource_scope(0, 1); auto vm_desc = ObjectMsgPtr::New(vm::TestUtil::NewVmResourceDesc().Get()); vm::TestUtil::AddStreamDescByInstrNames(vm_desc.Mutable(), {"RunLazyJob"}); @@ -122,11 +110,9 @@ TEST(RunLazyJobInstructionType, simple) { leave_thread.join(); enter_thread.join(); Global>>::Delete(); - DestroyNumProcessPerNode(); } TEST(RunLazyJobInstructionType, wait_for_another_job_finished) { - InitNumProcessPerNode(); vm::TestResourceDescScope resource_scope(0, 1); auto vm_desc = ObjectMsgPtr::New(vm::TestUtil::NewVmResourceDesc().Get()); vm::TestUtil::AddStreamDescByInstrNames(vm_desc.Mutable(), {"RunLazyJob"}); @@ -247,7 +233,6 @@ TEST(RunLazyJobInstructionType, wait_for_another_job_finished) { enter_thread0.join(); enter_thread1.join(); Global>>::Delete(); - DestroyNumProcessPerNode(); } } // namespace test diff --git a/oneflow/core/eager/opkernel_instruction_type_test.cpp b/oneflow/core/eager/opkernel_instruction_type_test.cpp index 85cbb7f1adb..4e3d1c842a2 100644 --- a/oneflow/core/eager/opkernel_instruction_type_test.cpp +++ b/oneflow/core/eager/opkernel_instruction_type_test.cpp @@ -40,17 +40,6 @@ namespace oneflow { namespace vm { namespace test { -namespace { - -void InitNumProcessPerNode() { - Global::New(); - Global::Get()->set_value(1); -} - -void DestroyNumProcessPerNode() { Global::Delete(); } - -} // namespace - using InstructionMsgList = OBJECT_MSG_LIST(vm::InstructionMsg, instr_msg_link); int64_t NewJobDescSymbol(InstructionMsgList* list, @@ -118,7 +107,6 @@ int64_t InitOpKernelObject(InstructionMsgList* list, } TEST(OpkernelInstructionType, new_opkernel) { - InitNumProcessPerNode(); #ifdef WITH_CUDA vm::TestResourceDescScope resource_scope(1, 1); const std::string device_tag = "gpu"; @@ -142,11 +130,9 @@ TEST(OpkernelInstructionType, new_opkernel) { vm->Schedule(); OBJECT_MSG_LIST_FOR_EACH_PTR(vm->mut_thread_ctx_list(), t) { t->TryReceiveAndRun(); } } - DestroyNumProcessPerNode(); } TEST(OpkernelInstructionType, delete_opkernel) { - InitNumProcessPerNode(); #ifdef WITH_CUDA vm::TestResourceDescScope resource_scope(1, 1); const std::string device_tag = "gpu"; @@ -173,11 +159,9 @@ TEST(OpkernelInstructionType, delete_opkernel) { vm->Schedule(); OBJECT_MSG_LIST_FOR_EACH_PTR(vm->mut_thread_ctx_list(), t) { t->TryReceiveAndRun(); } } - DestroyNumProcessPerNode(); } TEST(OpkernelInstructionType, call_opkernel) { - InitNumProcessPerNode(); #ifdef WITH_CUDA vm::TestResourceDescScope resource_scope(1, 1); const std::string device_tag = "gpu"; @@ -220,12 +204,10 @@ TEST(OpkernelInstructionType, call_opkernel) { vm->Schedule(); OBJECT_MSG_LIST_FOR_EACH_PTR(vm->mut_thread_ctx_list(), t) { t->TryReceiveAndRun(); } } - DestroyNumProcessPerNode(); } #ifdef WITH_CUDA TEST(OpkernelInstructionType, consecutive_opkernel_calls) { - InitNumProcessPerNode(); vm::TestResourceDescScope resource_scope(1, 1); InstructionMsgList list; int64_t in_id = 
vm::TestUtil::NewStringSymbol(&list, "in_0"); @@ -300,12 +282,10 @@ TEST(OpkernelInstructionType, consecutive_opkernel_calls) { vm->Schedule(); OBJECT_MSG_LIST_FOR_EACH_PTR(vm->mut_thread_ctx_list(), t) { t->TryReceiveAndRun(); } } - DestroyNumProcessPerNode(); } #endif TEST(OpkernelInstructionType, stateless_call_opkernel) { - InitNumProcessPerNode(); #ifdef WITH_CUDA vm::TestResourceDescScope resource_scope(1, 1); const std::string device_tag = "gpu"; @@ -351,12 +331,10 @@ TEST(OpkernelInstructionType, stateless_call_opkernel) { vm->Schedule(); OBJECT_MSG_LIST_FOR_EACH_PTR(vm->mut_thread_ctx_list(), t) { t->TryReceiveAndRun(); } } - DestroyNumProcessPerNode(); } #ifdef WITH_CUDA TEST(OpkernelInstructionType, consecutive_stateless_call_opkernel) { - InitNumProcessPerNode(); vm::TestResourceDescScope resource_scope(1, 1); InstructionMsgList list; int64_t job_desc_id = NewJobDescSymbol(&list, std::make_shared()); @@ -430,7 +408,6 @@ TEST(OpkernelInstructionType, consecutive_stateless_call_opkernel) { vm->Schedule(); OBJECT_MSG_LIST_FOR_EACH_PTR(vm->mut_thread_ctx_list(), t) { t->TryReceiveAndRun(); } } - DestroyNumProcessPerNode(); } #endif diff --git a/oneflow/core/eager/transport_blob_instruction_type_test.cpp b/oneflow/core/eager/transport_blob_instruction_type_test.cpp index 78c1d887785..ef7474ee7ce 100644 --- a/oneflow/core/eager/transport_blob_instruction_type_test.cpp +++ b/oneflow/core/eager/transport_blob_instruction_type_test.cpp @@ -251,18 +251,10 @@ class SendRecvUtil { std::string recv_instr_name_; }; -void InitNumProcessPerNode() { - Global::New(); - Global::Get()->set_value(1); -} - -void DestroyNumProcessPerNode() { Global::Delete(); } - } // namespace #ifdef __linux__ TEST(SendReceiveInstructionType, naive) { - InitNumProcessPerNode(); #ifdef WITH_CUDA vm::TestResourceDescScope scope(1, 1, 2); #else @@ -307,7 +299,6 @@ TEST(SendReceiveInstructionType, naive) { ASSERT_TRUE(token2recv_request.find(header_token) != token2recv_request.end()); ASSERT_TRUE(token2send_request.find(body_token) != token2send_request.end()); ASSERT_TRUE(token2recv_request.find(body_token) != token2recv_request.end()); - DestroyNumProcessPerNode(); } #endif // __linux__ diff --git a/oneflow/core/framework/consistent_tensor_infer_cache.cpp b/oneflow/core/framework/consistent_tensor_infer_cache.cpp index 07c4865e6a4..742268cf804 100644 --- a/oneflow/core/framework/consistent_tensor_infer_cache.cpp +++ b/oneflow/core/framework/consistent_tensor_infer_cache.cpp @@ -23,21 +23,35 @@ limitations under the License. 
namespace oneflow { namespace one { +namespace { + +bool OptionalEqual(const Optional>& lhs, + const Optional>& rhs) { + if (lhs.has_value() != rhs.has_value()) { return false; } + if (!lhs.has_value()) { return true; } + return CHECK_JUST(lhs.value()) == CHECK_JUST(rhs.value()); +} + +} // namespace + size_t InputConsistentTensorMeta::hash_value() const { - return std::hash>()(tensor_meta()) - ^ std::hash>()( - consumer_parallel_distribution_constraint()); + size_t hash_value = std::hash>()(tensor_meta()); + if (consumer_parallel_distribution_constraint().has_value()) { + hash_value ^= std::hash>()( + CHECK_JUST(consumer_parallel_distribution_constraint().value())); + } + return hash_value; } bool InputConsistentTensorMeta::operator==(const InputConsistentTensorMeta& other) const { return this->tensor_meta() == other.tensor_meta() - && this->consumer_parallel_distribution_constraint() - == other.consumer_parallel_distribution_constraint(); + && OptionalEqual(this->consumer_parallel_distribution_constraint(), + other.consumer_parallel_distribution_constraint()); } void InputConsistentTensorMeta::assign( Symbol tensor_meta, - Symbol consumer_parallel_distribution_constraint) { + const Optional>& consumer_parallel_distribution_constraint) { tensor_meta_ = tensor_meta; consumer_parallel_distribution_constraint_ = consumer_parallel_distribution_constraint; } @@ -77,7 +91,9 @@ Maybe ConsistentTensorMetaInferArgs::MakeParallelDistributionConstraints( for (int i = 0; i < input_arg_tuple.size(); ++i) { const auto& constaint = input_consistent_tensor_metas_.at(i).consumer_parallel_distribution_constraint(); - if (constaint) { (*map)[input_arg_tuple.indexed_bns().at(i)] = *constaint; } + if (constaint.has_value()) { + (*map)[input_arg_tuple.indexed_bns().at(i)] = *CHECK_JUST(constaint.value()); + } } return Maybe::Ok(); } @@ -136,8 +152,8 @@ Maybe ConsistentTensorMetaInferArgs::InitInputConsistentTensorMetas( for (int i = 0; i < input_tensors.size(); ++i) { const auto& tensor = *input_tensors.at(i); const auto& tensor_meta = JUST(tensor.consistent_tensor_meta()); - const auto& constraints = JUST(tensor.consumer_parallel_distribution_constraint()); - input_consistent_tensor_metas_.at(i).assign(tensor_meta, constraints); + const auto& constraint = JUST(tensor.consumer_parallel_distribution_constraint()); + input_consistent_tensor_metas_.at(i).assign(tensor_meta, constraint); } return Maybe::Ok(); } diff --git a/oneflow/core/framework/consistent_tensor_infer_cache.h b/oneflow/core/framework/consistent_tensor_infer_cache.h index 0520f3511b7..436bd7f5e72 100644 --- a/oneflow/core/framework/consistent_tensor_infer_cache.h +++ b/oneflow/core/framework/consistent_tensor_infer_cache.h @@ -18,6 +18,7 @@ limitations under the License. 
#include "oneflow/core/common/symbol.h" #include "oneflow/core/common/maybe.h" +#include "oneflow/core/common/optional.h" #include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/tensor_meta.h" #include "oneflow/core/register/blob_desc.h" @@ -41,7 +42,7 @@ class InputConsistentTensorMeta final { InputConsistentTensorMeta() : tensor_meta_(), consumer_parallel_distribution_constraint_() {} InputConsistentTensorMeta( Symbol tensor_meta, - Symbol consumer_parallel_distribution_constraint) + const Optional>& consumer_parallel_distribution_constraint) : tensor_meta_(tensor_meta), consumer_parallel_distribution_constraint_(consumer_parallel_distribution_constraint) {} @@ -52,15 +53,17 @@ class InputConsistentTensorMeta final { size_t hash_value() const; bool operator==(const InputConsistentTensorMeta& other) const; Symbol tensor_meta() const { return tensor_meta_; } - Symbol consumer_parallel_distribution_constraint() const { + const Optional>& consumer_parallel_distribution_constraint() + const { return consumer_parallel_distribution_constraint_; } - void assign(Symbol tensor_meta, - Symbol consumer_parallel_distribution_constraint); + void assign( + Symbol tensor_meta, + const Optional>& consumer_parallel_distribution_constraint); private: Symbol tensor_meta_; - Symbol consumer_parallel_distribution_constraint_; + Optional> consumer_parallel_distribution_constraint_; }; class TensorTuple; diff --git a/oneflow/core/framework/interpreter_test.cpp b/oneflow/core/framework/interpreter_test.cpp index 95d0787a41c..397bb40e3f0 100644 --- a/oneflow/core/framework/interpreter_test.cpp +++ b/oneflow/core/framework/interpreter_test.cpp @@ -30,19 +30,9 @@ namespace test { namespace { -void InitNumProcessPerNode() { - Global::New(); - Global::Get()->set_value(1); -} - -void DestroyNumProcessPerNode() { Global::Delete(); } - class TestVirtualMachineScope { public: TestVirtualMachineScope(int64_t gpu_device_num, int64_t cpu_device_num) { - InitNumProcessPerNode(); - Global::New(); - Global::Get()->set_rank(0); test_resource_desc_scope_.reset(new vm::TestResourceDescScope(gpu_device_num, cpu_device_num)); virtual_machine_scope_.reset( new vm::VirtualMachineScope(Global::Get()->resource())); @@ -51,8 +41,6 @@ class TestVirtualMachineScope { ~TestVirtualMachineScope() { virtual_machine_scope_.reset(); test_resource_desc_scope_.reset(); - Global::Delete(); - DestroyNumProcessPerNode(); } private: diff --git a/oneflow/core/framework/op_expr.cpp b/oneflow/core/framework/op_expr.cpp index 214d1b06c7a..0d28c9e1fbc 100644 --- a/oneflow/core/framework/op_expr.cpp +++ b/oneflow/core/framework/op_expr.cpp @@ -140,7 +140,7 @@ class UserOpExprInferContext : public user_op::InferContext { device_tag_(device_tag), tensor_meta4input_index_(TensorMeta4InputIndex), tensor_meta4output_index_(TensorMeta4OutputIndex) {} - ~UserOpExprInferContext() = default; + virtual ~UserOpExprInferContext() override = default; const std::vector>& inputs() const override { return user_op_expr_->indexed_input_pairs(); @@ -259,7 +259,7 @@ class UserOpExprInferContext : public user_op::InferContext { class UserOpExprLogicalInferContext final : public UserOpExprInferContext { public: using UserOpExprInferContext::UserOpExprInferContext; - ~UserOpExprLogicalInferContext() = default; + ~UserOpExprLogicalInferContext() override = default; const user_op::TensorDesc* LogicalTensorDesc4ArgNameAndIndex(const std::string& name, int32_t index) const override { diff --git 
a/oneflow/core/framework/op_interpreter/eager_consistent_op_interpreter.cpp b/oneflow/core/framework/op_interpreter/eager_consistent_op_interpreter.cpp index d42c5b4bd30..6966eeb7689 100644 --- a/oneflow/core/framework/op_interpreter/eager_consistent_op_interpreter.cpp +++ b/oneflow/core/framework/op_interpreter/eager_consistent_op_interpreter.cpp @@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ + #include #include "oneflow/core/framework/to_string.h" #include "oneflow/core/framework/op_interpreter.h" @@ -82,6 +83,8 @@ Maybe Interpret(const UserOpExpr& user_op_expr, const TensorTuple& inputs, for (int i = 0; i < outputs->size(); ++i) { const auto& tensor_impl = JUST(New(output_tensor_metas.at(i), device, parallel_id, false, false)); + const auto& rpc_token = JUST(RpcToken::NewMetaRpcToken()); + JUST(tensor_impl->set_rpc_token(rpc_token)); outputs->at(i).reset(new ConsistentTensor(tensor_impl)); } // Do nothing if the `parallel_desc` doesn't cover current ProcessCtx. diff --git a/oneflow/core/framework/rank_group_rpc_util.cpp b/oneflow/core/framework/rank_group_rpc_util.cpp new file mode 100644 index 00000000000..4ee2e22d969 --- /dev/null +++ b/oneflow/core/framework/rank_group_rpc_util.cpp @@ -0,0 +1,54 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ +#include +#include +#include "oneflow/core/framework/rank_group_rpc_util.h" +#include "oneflow/core/framework/rpc_util.h" +#include "oneflow/core/common/data_type.h" +#include "oneflow/core/common/container_util.h" +#include "oneflow/core/job/rank_group.h" +#include "oneflow/core/job/rank_group_scope.h" +#include "oneflow/core/job/parallel_desc.h" +#include "oneflow/core/thread/consistent_unique_id.h" +#include "oneflow/core/rpc/include/global_process_ctx.h" + +namespace oneflow { + +Maybe CheckRpcToken(Symbol rank_group) { + const auto& rpc_token = + JUST(RpcToken::AcquireCtrlRpcToken(kRankGroupRpcCmdCheckRankGroupConsistency)); + const auto& ctx = std::make_shared( + rpc_token, + [](void** buffer, std::size_t* size, std::function* Callback) -> Maybe { + const auto& placeholder = std::make_shared(); + *buffer = placeholder.get(); + *size = sizeof(uint32_t); + *Callback = [placeholder]() {}; + return Maybe::Ok(); + }); + JUST(RpcUtil::SendToNextRankInRing(rank_group, rpc_token, ctx.get())); + JUST(RpcUtil::ReceiveFromPrevRankInRing(rank_group, rpc_token, ctx.get())); + return ctx; +} + +Maybe GetCurrentRankGroupLevel() { + const auto& rank_group = JUST(RankGroupScope::CurrentRankGroup()); + const auto& root_rank_group = JUST(RankGroupScope::RootRankGroup()); + CHECK_OR_RETURN(rank_group == root_rank_group) << Error::Unimplemented(); + return static_cast(0); +} + +} // namespace oneflow diff --git a/oneflow/core/framework/rank_group_rpc_util.h b/oneflow/core/framework/rank_group_rpc_util.h new file mode 100644 index 00000000000..39f75c3ac79 --- /dev/null +++ b/oneflow/core/framework/rank_group_rpc_util.h @@ -0,0 +1,32 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#ifndef ONEFLOW_CORE_FRAMEWORK_PLACEMENT_RPC_UTIL_H_ +#define ONEFLOW_CORE_FRAMEWORK_PLACEMENT_RPC_UTIL_H_ + +#include "oneflow/core/framework/rpc_token.h" +#include "oneflow/core/framework/rpc_util.h" +#include "oneflow/core/common/symbol.h" +#include "oneflow/core/job/rank_group.h" + +namespace oneflow { + +Maybe CheckRpcToken(Symbol rank_group); + +Maybe GetCurrentRankGroupLevel(); + +} // namespace oneflow + +#endif // ONEFLOW_CORE_FRAMEWORK_PLACEMENT_RPC_UTIL_H_ diff --git a/oneflow/core/framework/rpc_token.cpp b/oneflow/core/framework/rpc_token.cpp new file mode 100644 index 00000000000..77ccadc5dcf --- /dev/null +++ b/oneflow/core/framework/rpc_token.cpp @@ -0,0 +1,285 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ +#include "oneflow/core/framework/rpc_token.h" +#include "oneflow/core/common/data_type.h" +#include "oneflow/core/common/data_type.h" +#include "oneflow/core/thread/consistent_unique_id.h" +#include "oneflow/core/framework/rank_group_rpc_util.h" + +namespace oneflow { + +namespace { + +class DataRpcTokenView final { + public: + static Maybe MutCast(RpcToken* rpc_token) { + CHECK_EQ_OR_RETURN(rpc_token->type(), kDataRpcTokenType); + return reinterpret_cast(rpc_token); + } + + void set_data_seq_id(int64_t seq_id) { data_seq_id_ = seq_id; } + + private: + uint16_t src_rank_; + uint16_t dst_rank_; + uint32_t type_ : 2; // RpcTokenType + uint32_t data_seq_id_ : 30; +}; +static_assert(sizeof(DataRpcTokenView) == sizeof(uint64_t), ""); + +class MetaRpcTokenView final { + public: + int64_t thread_consistent_unique_id() const { return thread_consistent_unique_id_; } + int64_t rank_group_level() const { return rank_group_level_; } + + static Maybe MutCast(RpcToken* rpc_token) { + CHECK_EQ_OR_RETURN(rpc_token->type(), kMetaRpcTokenType); + return reinterpret_cast(rpc_token); + } + + static Maybe Cast(const RpcToken* rpc_token) { + CHECK_EQ_OR_RETURN(rpc_token->type(), kMetaRpcTokenType); + return reinterpret_cast(rpc_token); + } + + Maybe set_thread_consistent_unique_id(int8_t val) { + CHECK_GE_OR_RETURN(val, 0); + CHECK_LT_OR_RETURN(val, 1 << kRpcTokenThreadConsistentUIdBit); + thread_consistent_unique_id_ = val; + return Maybe::Ok(); + } + + Maybe set_rank_group_level(int32_t val) { + CHECK_GE_OR_RETURN(val, 0); + CHECK_LT_OR_RETURN(val, 1 << kRpcTokenRankGroupLevelBit); + rank_group_level_ = val; + return Maybe::Ok(); + } + + MetaRpcTokenView& operator++() { + ++low_meta_seq_id_; + if (low_meta_seq_id_ == 0) { ++high_meta_seq_id_; } + return *this; + } + + private: + uint16_t src_rank_; + uint16_t dst_rank_; + uint8_t type_ : 2; // RpcTokenType + uint8_t thread_consistent_unique_id_ : kRpcTokenThreadConsistentUIdBit; + uint8_t rank_group_level_ : kRpcTokenRankGroupLevelBit; + uint8_t high_meta_seq_id_; + uint16_t low_meta_seq_id_; +}; +static_assert(sizeof(MetaRpcTokenView) == sizeof(uint64_t), ""); + +class CtrlRpcTokenView final { + public: + int64_t thread_consistent_unique_id() const { return thread_consistent_unique_id_; } + int64_t rank_group_level() const { return rank_group_level_; } + + static Maybe MutCast(RpcToken* rpc_token) { + CHECK_EQ_OR_RETURN(rpc_token->type(), kCtrlRpcTokenType); + return reinterpret_cast(rpc_token); + } + + static Maybe Cast(const RpcToken* rpc_token) { + CHECK_EQ_OR_RETURN(rpc_token->type(), kCtrlRpcTokenType); + return reinterpret_cast(rpc_token); + } + + Maybe set_thread_consistent_unique_id(int8_t val) { + CHECK_GE_OR_RETURN(val, 0); + CHECK_LT_OR_RETURN(val, 1 << kRpcTokenThreadConsistentUIdBit); + thread_consistent_unique_id_ = val; + return Maybe::Ok(); + } + Maybe set_rank_group_level(int32_t val) { + CHECK_GE_OR_RETURN(val, 0); + CHECK_LT_OR_RETURN(val, 1 << kRpcTokenRankGroupLevelBit); + rank_group_level_ = val; + return Maybe::Ok(); + } + + RankGroupRpcCmd cmd() const { return static_cast(cmd_); } + + void set_cmd(RankGroupRpcCmd cmd) { + static_assert(kSizeOfRankGroupRpcCmd < (1 << 8), ""); + cmd_ = static_cast(cmd); + } + + void set_ctrl_seq_id(int32_t val) { ctrl_seq_id_ = val; } + + private: + uint16_t src_rank_; + uint16_t dst_rank_; + uint8_t type_ : 2; // RpcTokenType + uint8_t thread_consistent_unique_id_ : kRpcTokenThreadConsistentUIdBit; + uint8_t rank_group_level_ : kRpcTokenRankGroupLevelBit; + uint8_t cmd_; + uint16_t ctrl_seq_id_; 
+}; +static_assert(sizeof(CtrlRpcTokenView) == sizeof(uint64_t), ""); + +} // namespace + +/*static*/ RpcToken RpcToken::NewDataRpcToken() { + static auto* seq_id = new std::atomic(); + RpcToken rpc_token(kDataRpcTokenType); + CHECK_JUST(DataRpcTokenView::MutCast(&rpc_token))->set_data_seq_id(++*seq_id); + return rpc_token; +} + +/*static*/ Maybe RpcToken::NewMetaRpcToken() { + int32_t thread_consistent_unique_id = JUST(GetThisThreadConsistentUniqueId()); + int32_t rank_group_level = JUST(GetCurrentRankGroupLevel()); + static const int kLimit = 128; + CHECK_GE_OR_RETURN(rank_group_level, 0); + CHECK_LT_OR_RETURN(rank_group_level, kLimit); + static thread_local std::array, kLimit> rpc_token_stack; + auto* current_rpc_token = &rpc_token_stack[rank_group_level]; + if (!*current_rpc_token) { + const auto& init = JUST(NewMetaRpcToken(thread_consistent_unique_id, rank_group_level)); + current_rpc_token->reset(new RpcToken(init)); + } + return ++**current_rpc_token; +} + +namespace { + +Maybe ThreadLocalMutLock4CtrlRpcToken(int32_t thread_consistent_unique_id, + int32_t rank_group_level, RankGroupRpcCmd cmd) { + CHECK_EQ_OR_RETURN(thread_consistent_unique_id, JUST(GetThisThreadConsistentUniqueId())); + static const int kRpcTokenRankGroupLevelLimit = (1 << kRpcTokenRankGroupLevelBit); + CHECK_LT_OR_RETURN(rank_group_level, kRpcTokenRankGroupLevelLimit); + static thread_local std::array, + kRpcTokenRankGroupLevelLimit> + rpc_token_lock; + return &rpc_token_lock[rank_group_level][cmd]; +} + +} // namespace + +/*static*/ Maybe RpcToken::AcquireCtrlRpcToken(RankGroupRpcCmd cmd) { + int32_t thread_consistent_unique_id = JUST(GetThisThreadConsistentUniqueId()); + int32_t rank_group_level = JUST(GetCurrentRankGroupLevel()); + auto* lock = + JUST(ThreadLocalMutLock4CtrlRpcToken(thread_consistent_unique_id, rank_group_level, cmd)); + CHECK_OR_RETURN(!*lock); + static const int kRpcTokenRankGroupLevelLimit = (1 << kRpcTokenRankGroupLevelBit); + static thread_local std::array, kSizeOfRankGroupRpcCmd>, + kRpcTokenRankGroupLevelLimit> + rpc_token_stack; + CHECK_GE_OR_RETURN(rank_group_level, 0); + CHECK_LT_OR_RETURN(rank_group_level, kRpcTokenRankGroupLevelLimit); + CHECK_GE_OR_RETURN(static_cast(cmd), 0); + CHECK_LT_OR_RETURN(static_cast(cmd), kSizeOfRankGroupRpcCmd); + auto* current_rpc_token = &rpc_token_stack[rank_group_level][cmd]; + if (!*current_rpc_token) { + const auto& init = JUST(NewCtrlRpcToken(cmd, thread_consistent_unique_id, rank_group_level)); + current_rpc_token->reset(new RpcToken(init)); + } + *lock = true; + return **current_rpc_token; +} + +Maybe RpcToken::ReleaseCtrlRpcToken() const { + auto* lock = JUST(ThreadLocalMutLock4CtrlRpcToken(JUST(thread_consistent_unique_id()), + JUST(rank_group_level()), JUST(cmd()))); + CHECK_OR_RETURN(*lock); + *lock = false; + return Maybe::Ok(); +} + +Maybe RpcToken::thread_consistent_unique_id() const { + if (type() == kMetaRpcTokenType) { + return JUST(MetaRpcTokenView::Cast(this))->thread_consistent_unique_id(); + } else if (type() == kCtrlRpcTokenType) { + return JUST(CtrlRpcTokenView::Cast(this))->thread_consistent_unique_id(); + } else { + UNIMPLEMENTED_THEN_RETURN(); + } + UNIMPLEMENTED_THEN_RETURN(); +} + +Maybe RpcToken::rank_group_level() const { + if (type() == kMetaRpcTokenType) { + return JUST(MetaRpcTokenView::Cast(this))->rank_group_level(); + } else if (type() == kCtrlRpcTokenType) { + return JUST(CtrlRpcTokenView::Cast(this))->rank_group_level(); + } else { + UNIMPLEMENTED_THEN_RETURN(); + } + UNIMPLEMENTED_THEN_RETURN(); +} + +Maybe 
RpcToken::cmd() const { return JUST(CtrlRpcTokenView::Cast(this))->cmd(); } + +Maybe RpcToken::set_src_rank(int64_t src_rank) { + CHECK_GE_OR_RETURN(src_rank, 0); + CHECK_LT_OR_RETURN(src_rank, GetMaxVal()); + src_rank_ = src_rank; + return Maybe::Ok(); +} + +Maybe RpcToken::set_dst_rank(int64_t dst_rank) { + CHECK_GE_OR_RETURN(dst_rank, 0); + CHECK_LT_OR_RETURN(dst_rank, GetMaxVal()); + dst_rank_ = dst_rank; + return Maybe::Ok(); +} + +RpcToken::operator uint64_t() const { + static_assert(sizeof(RpcToken) == sizeof(uint64_t), ""); + return *reinterpret_cast(this); +} + +RpcToken& RpcToken::operator++() { + RpcTokenType rpc_token_type = type(); + if (rpc_token_type == kDataRpcTokenType) { + UNIMPLEMENTED(); + } else if (rpc_token_type == kMetaRpcTokenType) { + ++*CHECK_JUST(MetaRpcTokenView::MutCast(this)); + } else if (rpc_token_type == kCtrlRpcTokenType) { + UNIMPLEMENTED(); + } else { + UNIMPLEMENTED(); + } + return *this; +} + +/*static*/ Maybe RpcToken::NewMetaRpcToken(int32_t thread_consistent_unique_id, + int32_t rank_group_level) { + RpcToken rpc_token(kMetaRpcTokenType); + auto* view = JUST(MetaRpcTokenView::MutCast(&rpc_token)); + JUST(view->set_thread_consistent_unique_id(thread_consistent_unique_id)); + JUST(view->set_rank_group_level(rank_group_level)); + return rpc_token; +} + +/*static*/ Maybe RpcToken::NewCtrlRpcToken(RankGroupRpcCmd cmd, + int32_t thread_consistent_unique_id, + int32_t rank_group_level) { + RpcToken rpc_token(kCtrlRpcTokenType); + auto* view = JUST(CtrlRpcTokenView::MutCast(&rpc_token)); + JUST(view->set_thread_consistent_unique_id(thread_consistent_unique_id)); + JUST(view->set_rank_group_level(rank_group_level)); + view->set_cmd(cmd); + view->set_ctrl_seq_id(0); + return rpc_token; +} + +} // namespace oneflow diff --git a/oneflow/core/framework/rpc_token.h b/oneflow/core/framework/rpc_token.h new file mode 100644 index 00000000000..3b72098d5bb --- /dev/null +++ b/oneflow/core/framework/rpc_token.h @@ -0,0 +1,106 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#ifndef ONEFLOW_CORE_FRAMEWORK_RPC_TOKEN_H_ +#define ONEFLOW_CORE_FRAMEWORK_RPC_TOKEN_H_ + +#include "oneflow/core/common/type_traits.h" +#include "oneflow/core/common/maybe.h" + +namespace oneflow { + +const static int kRpcTokenTypeBit = 2; +const static int kRpcTokenThreadConsistentUIdBit = 3; +const static int kRpcTokenRankGroupLevelBit = 3; + +enum RpcTokenType { + // Begin + kDataRpcTokenType = 0, // e.g. for tensor data transportation + kMetaRpcTokenType, // e.g. for tensor meta checking + kCtrlRpcTokenType, // e.g. for rank_group or thread checking. 
see RankGroupRpcCmd + kExtendedRpcTokenType, // for compatibility + // End + kRpcTokenTypeSize, +}; + +static_assert(kRpcTokenTypeSize <= (1 << kRpcTokenTypeBit), ""); + +enum RankGroupRpcCmd { + // Begin + kRankGroupRpcCmdInvalid = 0, + kRankGroupRpcCmdSyncSymbolParallelDesc, + kRankGroupRpcCmdSyncSymbolParallelDistribution, + kRankGroupRpcCmdSyncSymbolConsistentTensorMeta, + kRankGroupRpcCmdCheckRankGroupConsistency, + kRankGroupRpcCmdCheckTensorConsistency, + // End + kSizeOfRankGroupRpcCmd +}; + +class RpcToken; + +template<> +struct IsScalarType final { + static const bool value = true; +}; + +class RpcToken final { + public: + RpcToken(const RpcToken&) = default; + RpcToken(RpcToken&) = default; + ~RpcToken() = default; + + static RpcToken NewDataRpcToken(); + static Maybe NewMetaRpcToken(); + static Maybe AcquireCtrlRpcToken(RankGroupRpcCmd cmd); + Maybe ReleaseCtrlRpcToken() const; + + static constexpr size_t MaxNumberOfThreadConsistentUId() { + return (1 << kRpcTokenThreadConsistentUIdBit); + } + + // Getters + int64_t src_rank() const { return src_rank_; } + int64_t dst_rank() const { return dst_rank_; } + RpcTokenType type() const { return static_cast(type_); } + Maybe thread_consistent_unique_id() const; + Maybe rank_group_level() const; + Maybe cmd() const; + + // Setters + Maybe set_src_rank(int64_t src_rank); + Maybe set_dst_rank(int64_t dst_rank); + + operator uint64_t() const; + RpcToken& operator++(); + + private: + explicit RpcToken(RpcTokenType type) : type_(type) {} + + static Maybe NewMetaRpcToken(int32_t thread_consistent_unique_id, + int32_t rank_group_level); + static Maybe NewCtrlRpcToken(RankGroupRpcCmd cmd, int32_t thread_consistent_unique_id, + int32_t rank_group_level); + + uint16_t src_rank_; + uint16_t dst_rank_; + uint32_t type_ : 2; // RpcTokenType + uint32_t opaque_ids_ : 30; +}; +static_assert(sizeof(RpcToken) == sizeof(uint64_t), ""); + +} // namespace oneflow + +#endif // ONEFLOW_CORE_FRAMEWORK_RPC_TOKEN_H_ diff --git a/oneflow/core/framework/rpc_util.cpp b/oneflow/core/framework/rpc_util.cpp new file mode 100644 index 00000000000..0d73e427383 --- /dev/null +++ b/oneflow/core/framework/rpc_util.cpp @@ -0,0 +1,134 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ +#include +#include +#include "oneflow/core/framework/rpc_token.h" +#include "oneflow/core/framework/rpc_util.h" +#include "oneflow/core/job/parallel_desc.h" +#include "oneflow/core/transport/transport.h" +#include "oneflow/core/thread/consistent_unique_id.h" +#include "oneflow/core/job/rank_group.h" +#include "oneflow/core/common/data_type.h" +#include "oneflow/core/rpc/include/global_process_ctx.h" + +namespace oneflow { + +/*static*/ Maybe RpcUtil::WaitUntilDoneOrTimeout(const AsyncRpcCtx& ctx, int64_t seconds) { + const auto& start = std::chrono::steady_clock::now(); + const auto& cond_cnt = ctx.flying_cnt(); + while (*cond_cnt > 0) { + auto end = std::chrono::steady_clock::now(); + std::chrono::duration elapsed_seconds = end - start; + CHECK_LT_OR_RETURN(elapsed_seconds.count(), seconds) + << Error::TimeoutError() << "Timeout error at " << seconds << " seconds."; + } + if (ctx.rpc_token().type() == kCtrlRpcTokenType) { ctx.rpc_token().ReleaseCtrlRpcToken(); } + return Maybe::Ok(); +} + +namespace { + +template (*SendOrRecv)(const RpcToken&, int64_t, void*, std::size_t, + const std::function&)> +Maybe AccessToAllOtherRanks(Symbol rank_group, const RpcToken& token, + AsyncRpcCtx* ctx) { + CHECK_OR_RETURN(rank_group->ContainingCurrentRank()); + const auto& flying_cnt = ctx->flying_cnt(); + JUST(rank_group->ForEachRank([&](int64_t rank) -> Maybe { + if (rank == GlobalProcessCtx::Rank()) { return Maybe::Ok(); } + ++*flying_cnt; + void* buffer = nullptr; + std::size_t size = 0; + std::function Callback; + JUST(ctx->MakeDataBufferAndCallback(rank, &buffer, &size, &Callback)); + JUST(SendOrRecv(token, rank, buffer, size, [flying_cnt, Callback]() { + Callback(); + --*flying_cnt; + })); + return Maybe::Ok(); + })); + return Maybe::Ok(); +} + +template (RankGroup::*GetPrevOrNext)() const, + Maybe (*SendOrRecv)(const RpcToken&, int64_t, void*, std::size_t, + const std::function&)> +Maybe AccessToNearbyRank(Symbol rank_group, const RpcToken& token, + AsyncRpcCtx* ctx) { + if (rank_group->size() == 1) { return Maybe::Ok(); } + const auto* rank_ranges_ptr = &*rank_group; + int64_t rank = JUST((rank_ranges_ptr->*GetPrevOrNext)()); + CHECK_NE_OR_RETURN(rank, GlobalProcessCtx::Rank()); + const auto& flying_cnt = ctx->flying_cnt(); + ++*flying_cnt; + void* buffer = nullptr; + std::size_t size = 0; + std::function Callback; + JUST(ctx->MakeDataBufferAndCallback(rank, &buffer, &size, &Callback)); + JUST(SendOrRecv(token, rank, buffer, size, [flying_cnt, Callback]() { + Callback(); + --*flying_cnt; + })); + return Maybe::Ok(); +} + +Maybe Send(const RpcToken& token, int64_t rank, void* buffer, std::size_t size, + const std::function& Callback) { + auto* transport = JUST(GlobalMaybe()); + RpcToken transport_token(token); + JUST(transport_token.set_src_rank(GlobalProcessCtx::Rank())); + JUST(transport_token.set_dst_rank(rank)); + transport->Send(static_cast(transport_token), rank, buffer, size, Callback); + return Maybe::Ok(); +} + +Maybe Recv(const RpcToken& token, int64_t rank, void* buffer, std::size_t size, + const std::function& Callback) { + auto* transport = JUST(GlobalMaybe()); + RpcToken transport_token(token); + JUST(transport_token.set_src_rank(rank)); + JUST(transport_token.set_dst_rank(GlobalProcessCtx::Rank())); + transport->Receive(static_cast(transport_token), rank, buffer, size, Callback); + return Maybe::Ok(); +} + +} // namespace + +/*static*/ Maybe RpcUtil::BroadcastToAllOtherRanks(Symbol rank_group, + const RpcToken& token, AsyncRpcCtx* ctx) { + 
JUST(AccessToAllOtherRanks<&Send>(rank_group, token, ctx)); + return Maybe::Ok(); +} + +/*static*/ Maybe RpcUtil::CollectFromAllOtherRanks(Symbol rank_group, + const RpcToken& token, AsyncRpcCtx* ctx) { + JUST(AccessToAllOtherRanks<&Recv>(rank_group, token, ctx)); + return Maybe::Ok(); +} + +/*static*/ Maybe RpcUtil::SendToNextRankInRing(Symbol rank_group, + const RpcToken& token, AsyncRpcCtx* ctx) { + JUST(AccessToNearbyRank<&RankGroup::GetNextRankInRing, &Send>(rank_group, token, ctx)); + return Maybe::Ok(); +} + +/*static*/ Maybe RpcUtil::ReceiveFromPrevRankInRing(Symbol rank_group, + const RpcToken& token, AsyncRpcCtx* ctx) { + JUST(AccessToNearbyRank<&RankGroup::GetPrevRankInRing, &Recv>(rank_group, token, ctx)); + return Maybe::Ok(); +} + +} // namespace oneflow diff --git a/oneflow/core/framework/rpc_util.h b/oneflow/core/framework/rpc_util.h new file mode 100644 index 00000000000..74bf84c2155 --- /dev/null +++ b/oneflow/core/framework/rpc_util.h @@ -0,0 +1,82 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#ifndef ONEFLOW_CORE_FRAMEWORK_RPC_UTIL_H_ +#define ONEFLOW_CORE_FRAMEWORK_RPC_UTIL_H_ + +#include +#include "oneflow/core/common/maybe.h" +#include "oneflow/core/common/symbol.h" +#include "oneflow/core/framework/rpc_token.h" + +namespace oneflow { + +class AsyncRpcCtx { + public: + explicit AsyncRpcCtx(const RpcToken& rpc_token) + : rpc_token_(rpc_token), flying_cnt_(new std::atomic(0)) {} + virtual ~AsyncRpcCtx() = default; + + const RpcToken& rpc_token() const { return rpc_token_; } + std::shared_ptr> flying_cnt() const { return flying_cnt_; } + + virtual Maybe MakeDataBufferAndCallback(int64_t rank, void** buffer, std::size_t* size, + std::function* Callback) = 0; + + private: + RpcToken rpc_token_; + std::shared_ptr> flying_cnt_; +}; + +class NaiveAsyncRpcCtx final : public AsyncRpcCtx { + public: + explicit NaiveAsyncRpcCtx( + const RpcToken& rpc_token, + const std::function(void**, std::size_t*, std::function*)>& Prepare) + : AsyncRpcCtx(rpc_token), prepare_(Prepare) {} + ~NaiveAsyncRpcCtx() override = default; + + Maybe MakeDataBufferAndCallback(int64_t rank, void** buffer, std::size_t* size, + std::function* Callback) override { + return prepare_(buffer, size, Callback); + } + + private: + std::function(void**, std::size_t*, std::function*)> prepare_; +}; + +class RankGroup; + +struct RpcUtil final { + static int64_t TimeoutSeconds() { return 60 * 5; } + + static Maybe WaitUntilDoneOrTimeout(const AsyncRpcCtx& ctx, int64_t seconds); + + static Maybe SendToNextRankInRing(Symbol rank_group, const RpcToken& token, + AsyncRpcCtx* ctx); + + static Maybe ReceiveFromPrevRankInRing(Symbol rank_group, const RpcToken& token, + AsyncRpcCtx* ctx); + + static Maybe BroadcastToAllOtherRanks(Symbol rank_group, const RpcToken& token, + AsyncRpcCtx* ctx); + + static Maybe CollectFromAllOtherRanks(Symbol rank_group, const RpcToken& token, + AsyncRpcCtx* ctx); +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_FRAMEWORK_RPC_UTIL_H_ diff 
--git a/oneflow/core/framework/sync_symbol_consistent_tensor_meta.cpp b/oneflow/core/framework/sync_symbol_consistent_tensor_meta.cpp new file mode 100644 index 00000000000..82c4fbbe59e --- /dev/null +++ b/oneflow/core/framework/sync_symbol_consistent_tensor_meta.cpp @@ -0,0 +1,100 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/core/framework/sync_symbol_consistent_tensor_meta.h" +#include "oneflow/core/framework/sync_symbol_parallel_desc.h" +#include "oneflow/core/framework/sync_symbol_parallel_distribution.h" +#include "oneflow/core/framework/rank_group_rpc_util.h" +#include "oneflow/core/framework/tensor_meta.h" +#include "oneflow/core/framework/synced_symbol_map.h" +#include "oneflow/core/common/flat_shape.h" + +namespace oneflow { + +struct FlatConsistentTensorMeta final { + static Maybe New( + uint64_t symbol_id, Symbol consistent_tensor_meta) { + const auto& meta = std::make_shared(); + JUST(meta->Init(symbol_id, consistent_tensor_meta)); + return meta; + } + + Maybe Init(uint64_t symbol_id, Symbol consistent_tensor_meta) { + this->symbol_id = symbol_id; + JUST(this->shape.Init(consistent_tensor_meta->shape_ptr())); + this->dtype = static_cast(consistent_tensor_meta->dtype()); + this->is_dynamic = consistent_tensor_meta->is_dynamic(); + this->parallel_distribution = JUST(SyncedSymbolMap::FindOrSync( + consistent_tensor_meta->parallel_distribution(), &SyncSymbolParallelDistribution)); + this->parallel_desc = JUST(SyncedSymbolMap::FindOrSync( + consistent_tensor_meta->parallel_desc(), &SyncSymbolParallelDesc)); + return Maybe::Ok(); + } + + Maybe Check(uint64_t symbol_id, Symbol consistent_tensor_meta) { + CHECK_EQ_OR_RETURN(this->symbol_id, symbol_id); + JUST(this->shape.Check(consistent_tensor_meta->shape_ptr())); + CHECK_EQ_OR_RETURN(static_cast(this->dtype), consistent_tensor_meta->dtype()); + CHECK_EQ_OR_RETURN(this->is_dynamic, consistent_tensor_meta->is_dynamic()); + const auto& parallel_distribution = + JUST(SyncedSymbolMap::Symbol4SyncedSymbolId( + this->parallel_distribution)); + CHECK_OR_RETURN(parallel_distribution == consistent_tensor_meta->parallel_distribution()); + const auto& parallel_desc = + JUST(SyncedSymbolMap::Symbol4SyncedSymbolId(this->parallel_desc)); + CHECK_OR_RETURN(parallel_desc == consistent_tensor_meta->parallel_desc()); + return Maybe::Ok(); + } + + uint64_t symbol_id; + FlatShape shape; + int32_t dtype; + bool is_dynamic; + uint64_t parallel_distribution; + uint64_t parallel_desc; +}; + +Maybe SyncSymbolConsistentTensorMeta( + uint64_t symbol_id, Symbol consistent_tensor_meta) { + const auto& rpc_token = + JUST(RpcToken::AcquireCtrlRpcToken(kRankGroupRpcCmdSyncSymbolConsistentTensorMeta)); + NaiveAsyncRpcCtx send_ctx( + rpc_token, [&](void** buffer, std::size_t* size, std::function* Cb) -> Maybe { + const auto& send_buffer = + JUST(FlatConsistentTensorMeta::New(symbol_id, consistent_tensor_meta)); + *buffer = send_buffer.get(); + *size = sizeof(FlatConsistentTensorMeta); + *Cb = 
[send_buffer] {}; + return Maybe::Ok(); + }); + const auto& recv_buffer = std::make_shared(); + NaiveAsyncRpcCtx recv_ctx( + rpc_token, + [recv_buffer](void** buffer, std::size_t* size, std::function* Cb) -> Maybe { + *buffer = recv_buffer.get(); + *size = sizeof(FlatConsistentTensorMeta); + *Cb = [recv_buffer] {}; + return Maybe::Ok(); + }); + const auto& rank_group = JUST(RankGroupScope::CurrentRankGroup()); + JUST(RpcUtil::SendToNextRankInRing(rank_group, rpc_token, &send_ctx)); + JUST(RpcUtil::ReceiveFromPrevRankInRing(rank_group, rpc_token, &recv_ctx)); + JUST(RpcUtil::WaitUntilDoneOrTimeout(send_ctx, RpcUtil::TimeoutSeconds())); + JUST(RpcUtil::WaitUntilDoneOrTimeout(recv_ctx, RpcUtil::TimeoutSeconds())); + JUST(recv_buffer->Check(symbol_id, consistent_tensor_meta)); + return Maybe::Ok(); +} + +} // namespace oneflow diff --git a/oneflow/core/framework/sync_symbol_consistent_tensor_meta.h b/oneflow/core/framework/sync_symbol_consistent_tensor_meta.h new file mode 100644 index 00000000000..be61e3dd0b9 --- /dev/null +++ b/oneflow/core/framework/sync_symbol_consistent_tensor_meta.h @@ -0,0 +1,34 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#ifndef ONEFLOW_CORE_FRAMEWORK_SYNC_SYMBOL_CONSISTENT_TENSOR_META_H_ +#define ONEFLOW_CORE_FRAMEWORK_SYNC_SYMBOL_CONSISTENT_TENSOR_META_H_ + +#include "oneflow/core/common/maybe.h" +#include "oneflow/core/common/symbol.h" +#include "oneflow/core/framework/rpc_util.h" +#include "oneflow/core/framework/rpc_token.h" + +namespace oneflow { + +namespace one { +class ConsistentTensorMeta; +} + +Maybe SyncSymbolConsistentTensorMeta(uint64_t symbol_id, Symbol); + +} // namespace oneflow + +#endif // ONEFLOW_CORE_FRAMEWORK_SYNC_SYMBOL_CONSISTENT_TENSOR_META_H_ diff --git a/oneflow/core/framework/sync_symbol_parallel_desc.cpp b/oneflow/core/framework/sync_symbol_parallel_desc.cpp new file mode 100644 index 00000000000..579aa0f28c5 --- /dev/null +++ b/oneflow/core/framework/sync_symbol_parallel_desc.cpp @@ -0,0 +1,99 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
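SyncSymbolConsistentTensorMeta above, and SyncSymbolParallelDesc / SyncSymbolParallelDistribution further down, all repeat one shape: flatten the symbol into a POD, send it to the next rank in the ring, receive the neighbour's copy from the previous rank, wait on both contexts, then Check the received copy against local state. The sketch below is a hypothetical generic helper that captures that shared shape; it is not part of this patch. FlatT is assumed to expose the New/Check pair used by FlatConsistentTensorMeta, and the fixed sizeof(FlatT) transfer matches the tensor-meta case (FlatParallelConf below instead transmits only its available_size()).

#include <functional>
#include <memory>
#include "oneflow/core/framework/rpc_util.h"
#include "oneflow/core/framework/rpc_token.h"
#include "oneflow/core/job/rank_group_scope.h"

namespace oneflow {

// Hypothetical refactor sketch: FlatT must provide
//   static Maybe<FlatT> New(uint64_t symbol_id, Symbol<SymbolT>);
//   Maybe<void> Check(uint64_t symbol_id, Symbol<SymbolT>) const;
template<typename FlatT, typename SymbolT>
Maybe<void> RingCheckFlatSymbol(const RpcToken& rpc_token, uint64_t symbol_id,
                                Symbol<SymbolT> symbol) {
  NaiveAsyncRpcCtx send_ctx(
      rpc_token,
      [&](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {
        const auto& send_buffer = JUST(FlatT::New(symbol_id, symbol));
        *buffer = send_buffer.get();
        *size = sizeof(FlatT);
        *Cb = [send_buffer] {};  // keep the flattened copy alive until the send completes
        return Maybe<void>::Ok();
      });
  const auto& recv_buffer = std::make_shared<FlatT>();
  NaiveAsyncRpcCtx recv_ctx(
      rpc_token,
      [recv_buffer](void** buffer, std::size_t* size, std::function<void()>* Cb) -> Maybe<void> {
        *buffer = recv_buffer.get();
        *size = sizeof(FlatT);
        *Cb = [recv_buffer] {};
        return Maybe<void>::Ok();
      });
  const auto& rank_group = JUST(RankGroupScope::CurrentRankGroup());
  JUST(RpcUtil::SendToNextRankInRing(rank_group, rpc_token, &send_ctx));
  JUST(RpcUtil::ReceiveFromPrevRankInRing(rank_group, rpc_token, &recv_ctx));
  JUST(RpcUtil::WaitUntilDoneOrTimeout(send_ctx, RpcUtil::TimeoutSeconds()));
  JUST(RpcUtil::WaitUntilDoneOrTimeout(recv_ctx, RpcUtil::TimeoutSeconds()));
  // A mismatch between the neighbour's copy and local state becomes a returned error.
  JUST(recv_buffer->Check(symbol_id, symbol));
  return Maybe<void>::Ok();
}

}  // namespace oneflow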
+*/ +#include "oneflow/core/framework/sync_symbol_parallel_desc.h" +#include "oneflow/core/framework/rank_group_rpc_util.h" +#include "oneflow/core/job/rank_group_scope.h" +#include "oneflow/core/job/parallel_desc.h" + +namespace oneflow { + +namespace { + +static const int kLimitParallelConfString = 1024 * 8; +struct FlatParallelConf { + size_t available_size() const { + CHECK_GE(this->buffer_size, 0); + CHECK_LT(this->buffer_size, kLimitParallelConfString); + return sizeof(FlatParallelConf) - kLimitParallelConfString + this->buffer_size; + } + + size_t capacity() const { return sizeof(FlatParallelConf); } + + static Maybe New(uint64_t symbol_id, Symbol parallel_desc) { + const auto& data = std::make_shared(); + JUST(data->Init(symbol_id, parallel_desc)); + return data; + } + + Maybe Init(uint64_t symbol_id, Symbol parallel_desc) { + const auto& parallel_conf = parallel_desc->parallel_conf(); + int64_t byte_size = parallel_conf.ByteSize(); + CHECK_LE_OR_RETURN(byte_size, kLimitParallelConfString); + this->symbol_id = symbol_id; + this->buffer_size = byte_size; + CHECK_OR_RETURN(parallel_conf.SerializeToArray(this->buffer, kLimitParallelConfString)); + return Maybe::Ok(); + } + + Maybe Check(uint64_t symbol_id, Symbol parallel_desc) const { + const auto& parallel_conf = parallel_desc->parallel_conf(); + int64_t byte_size = parallel_conf.ByteSize(); + CHECK_LE_OR_RETURN(byte_size, kLimitParallelConfString); + CHECK_EQ_OR_RETURN(this->symbol_id, symbol_id); + CHECK_EQ_OR_RETURN(this->buffer_size, byte_size); + std::vector serialized(byte_size); + CHECK_OR_RETURN(parallel_conf.SerializeToArray(serialized.data(), kLimitParallelConfString)); + CHECK_EQ_OR_RETURN(std::memcmp(serialized.data(), this->buffer, byte_size), 0); + return Maybe::Ok(); + } + + uint64_t symbol_id; + uint64_t buffer_size; + char buffer[kLimitParallelConfString]; +}; + +} // namespace + +Maybe SyncSymbolParallelDesc(uint64_t symbol_id, Symbol parallel_desc) { + const auto& rpc_token = + JUST(RpcToken::AcquireCtrlRpcToken(kRankGroupRpcCmdSyncSymbolParallelDesc)); + NaiveAsyncRpcCtx send_ctx( + rpc_token, [&](void** buffer, std::size_t* size, std::function* Cb) -> Maybe { + const auto& send_buffer = JUST(FlatParallelConf::New(symbol_id, parallel_desc)); + *buffer = send_buffer.get(); + *size = send_buffer->available_size(); + *Cb = [send_buffer] {}; + return Maybe::Ok(); + }); + const auto& recv_buffer = std::make_shared(); + NaiveAsyncRpcCtx recv_ctx( + rpc_token, + [recv_buffer](void** buffer, std::size_t* size, std::function* Cb) -> Maybe { + *buffer = recv_buffer.get(); + *size = recv_buffer->capacity(); + *Cb = [recv_buffer] {}; + return Maybe::Ok(); + }); + const auto& rank_group = JUST(RankGroupScope::CurrentRankGroup()); + JUST(RpcUtil::SendToNextRankInRing(rank_group, rpc_token, &send_ctx)); + JUST(RpcUtil::ReceiveFromPrevRankInRing(rank_group, rpc_token, &recv_ctx)); + JUST(RpcUtil::WaitUntilDoneOrTimeout(send_ctx, RpcUtil::TimeoutSeconds())); + JUST(RpcUtil::WaitUntilDoneOrTimeout(recv_ctx, RpcUtil::TimeoutSeconds())); + JUST(recv_buffer->Check(symbol_id, parallel_desc)); + return Maybe::Ok(); +} + +} // namespace oneflow diff --git a/oneflow/core/framework/sync_symbol_parallel_desc.h b/oneflow/core/framework/sync_symbol_parallel_desc.h new file mode 100644 index 00000000000..02309e7b30c --- /dev/null +++ b/oneflow/core/framework/sync_symbol_parallel_desc.h @@ -0,0 +1,32 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#ifndef ONEFLOW_CORE_FRAMEWORK_SYNC_SYMBOL_PARALLEL_DESC_H_ +#define ONEFLOW_CORE_FRAMEWORK_SYNC_SYMBOL_PARALLEL_DESC_H_ + +#include "oneflow/core/common/maybe.h" +#include "oneflow/core/common/symbol.h" +#include "oneflow/core/framework/rpc_util.h" +#include "oneflow/core/framework/rpc_token.h" + +namespace oneflow { + +class ParallelDesc; + +Maybe SyncSymbolParallelDesc(uint64_t symbol_id, Symbol); + +} // namespace oneflow + +#endif // ONEFLOW_CORE_FRAMEWORK_SYNC_SYMBOL_PARALLEL_DESC_H_ diff --git a/oneflow/core/framework/sync_symbol_parallel_distribution.cpp b/oneflow/core/framework/sync_symbol_parallel_distribution.cpp new file mode 100644 index 00000000000..8adf9d8038c --- /dev/null +++ b/oneflow/core/framework/sync_symbol_parallel_distribution.cpp @@ -0,0 +1,173 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ +#include "oneflow/core/object_msg/flat_msg.h" +#include "oneflow/core/framework/sync_symbol_parallel_distribution.h" +#include "oneflow/core/framework/rank_group_rpc_util.h" +#include "oneflow/core/job/rank_group_scope.h" +#include "oneflow/core/job/sbp_parallel.cfg.h" +#include "oneflow/core/common/shape_vec.h" + +namespace oneflow { + +namespace { + +FLAT_MSG_BEGIN(FlatSplitParallel); +FLAT_MSG_DEFINE_OPTIONAL(int64_t, axis); +FLAT_MSG_END(FlatSplitParallel); + +FLAT_MSG_BEGIN(FlatBroadcastParallel); +FLAT_MSG_END(FlatBroadcastParallel); + +FLAT_MSG_BEGIN(FlatPartialSumParallel); +FLAT_MSG_END(FlatPartialSumParallel); + +FLAT_MSG_BEGIN(FlatSbpParallel); +Maybe Init(const cfg::SbpParallel& sbp_parallel) { + if (sbp_parallel.has_split_parallel()) { + this->mutable_split_parallel()->set_axis(sbp_parallel.split_parallel().axis()); + } else if (sbp_parallel.has_broadcast_parallel()) { + this->mutable_broadcast_parallel(); + } else if (sbp_parallel.has_partial_sum_parallel()) { + this->mutable_partial_sum_parallel(); + } else { + OF_UNIMPLEMENTED(); + } + return Maybe::Ok(); +} + +Maybe Check(const cfg::SbpParallel& sbp_parallel) const { + if (sbp_parallel.has_split_parallel()) { + CHECK_EQ_OR_RETURN(this->split_parallel().axis(), sbp_parallel.split_parallel().axis()); + } else if (sbp_parallel.has_broadcast_parallel()) { + CHECK_OR_RETURN(this->has_broadcast_parallel()); + } else if (sbp_parallel.has_partial_sum_parallel()) { + CHECK_OR_RETURN(this->has_partial_sum_parallel()); + } else { + OF_UNIMPLEMENTED(); + } + return Maybe::Ok(); +} + +FLAT_MSG_DEFINE_ONEOF(parallel_type, + FLAT_MSG_ONEOF_FIELD(FlatSplitParallel, split_parallel) + FLAT_MSG_ONEOF_FIELD(FlatBroadcastParallel, broadcast_parallel) + FLAT_MSG_ONEOF_FIELD(FlatPartialSumParallel, partial_sum_parallel)); +FLAT_MSG_END(FlatSbpParallel); + +FLAT_MSG_BEGIN(FlatParallelDistribution); +OF_PUBLIC Maybe Init(uint64_t symbol_id, + Symbol parallel_distribution) { + this->set_symbol_id(symbol_id); + this->set_size(parallel_distribution->sbp_parallel_size()); + for (int i = 0; i < this->size(); ++i) { + const auto& sbp_parallel = parallel_distribution->sbp_parallel(i); + JUST(this->mutable_sbp_parallel()->Mutable(i)->Init(sbp_parallel)); + } + return Maybe::Ok(); +} + +OF_PUBLIC Maybe Check(uint64_t symbol_id, + Symbol parallel_distribution) const { + CHECK_EQ_OR_RETURN(this->symbol_id(), symbol_id); + CHECK_EQ_OR_RETURN(this->size(), parallel_distribution->sbp_parallel_size()); + for (int i = 0; i < this->size(); ++i) { + JUST(this->sbp_parallel().Get(i).Check(parallel_distribution->sbp_parallel(i))); + } + return Maybe::Ok(); +} + +FLAT_MSG_DEFINE_OPTIONAL(uint64_t, symbol_id); +FLAT_MSG_DEFINE_OPTIONAL(size_t, size); +FLAT_MSG_DEFINE_REPEATED(FlatSbpParallel, sbp_parallel, SHAPE_MAX_AXIS_SIZE); +FLAT_MSG_END(FlatParallelDistribution); + +class SendFlatParallelDistributionAsyncRpcCtx : public AsyncRpcCtx { + public: + SendFlatParallelDistributionAsyncRpcCtx(const RpcToken& rpc_token, uint64_t symbol_id, + Symbol parallel_distribution) + : AsyncRpcCtx(rpc_token), + symbol_id_(symbol_id), + parallel_distribution_(parallel_distribution) {} + + ~SendFlatParallelDistributionAsyncRpcCtx() override {} + + Maybe MakeDataBufferAndCallback(int64_t rank, void** buffer, std::size_t* size, + std::function* Callback) override { + const auto& flat_parallel_distribution = std::make_shared(); + JUST(flat_parallel_distribution->Init(symbol_id_, parallel_distribution_)); + *buffer = flat_parallel_distribution.get(); + *size = 
sizeof(FlatParallelDistribution); + *Callback = [flat_parallel_distribution]() {}; + return Maybe::Ok(); + } + + private: + uint64_t symbol_id_; + Symbol parallel_distribution_; +}; + +class RecvFlatParallelDistributionAsyncRpcCtx : public AsyncRpcCtx { + public: + RecvFlatParallelDistributionAsyncRpcCtx(const RpcToken& rpc_token, uint64_t symbol_id, + Symbol parallel_distribution) + : AsyncRpcCtx(rpc_token), + symbol_id_(symbol_id), + parallel_distribution_(parallel_distribution) {} + + ~RecvFlatParallelDistributionAsyncRpcCtx() override {} + + Maybe MakeDataBufferAndCallback(int64_t rank, void** buffer, std::size_t* size, + std::function* Callback) override { + const auto& flat_parallel_distribution = std::make_shared(); + *buffer = flat_parallel_distribution.get(); + *size = sizeof(FlatParallelDistribution); + *Callback = [flat_parallel_distribution]() {}; + flat_parallel_distribution_ = flat_parallel_distribution; + return Maybe::Ok(); + } + + Maybe Check() const { + CHECK_NOTNULL_OR_RETURN(flat_parallel_distribution_.get()); + JUST(flat_parallel_distribution_->Check(symbol_id_, parallel_distribution_)); + return Maybe::Ok(); + } + + private: + uint64_t symbol_id_; + Symbol parallel_distribution_; + std::shared_ptr flat_parallel_distribution_; +}; + +} // namespace + +namespace {} + +Maybe SyncSymbolParallelDistribution(uint64_t symbol_id, + Symbol symbol) { + const auto& rank_group = JUST(RankGroupScope::CurrentRankGroup()); + const auto& rpc_token = + JUST(RpcToken::AcquireCtrlRpcToken(kRankGroupRpcCmdSyncSymbolParallelDistribution)); + SendFlatParallelDistributionAsyncRpcCtx send_ctx(rpc_token, symbol_id, symbol); + RecvFlatParallelDistributionAsyncRpcCtx recv_ctx(rpc_token, symbol_id, symbol); + JUST(RpcUtil::SendToNextRankInRing(rank_group, rpc_token, &send_ctx)); + JUST(RpcUtil::ReceiveFromPrevRankInRing(rank_group, rpc_token, &recv_ctx)); + JUST(RpcUtil::WaitUntilDoneOrTimeout(send_ctx, RpcUtil::TimeoutSeconds())); + JUST(RpcUtil::WaitUntilDoneOrTimeout(recv_ctx, RpcUtil::TimeoutSeconds())); + JUST(recv_ctx.Check()); + return Maybe::Ok(); +} + +} // namespace oneflow diff --git a/oneflow/core/framework/sync_symbol_parallel_distribution.h b/oneflow/core/framework/sync_symbol_parallel_distribution.h new file mode 100644 index 00000000000..d14cce93b87 --- /dev/null +++ b/oneflow/core/framework/sync_symbol_parallel_distribution.h @@ -0,0 +1,36 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
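The FLAT_MSG machinery above gives FlatSbpParallel a trivially copyable oneof over split / broadcast / partial-sum, so an entire parallel distribution crosses the wire as a single POD holding at most SHAPE_MAX_AXIS_SIZE entries plus a used-size prefix. The plain-C++ analogue below is conceptual only: it does not use the FLAT_MSG macros, and all names in it (SbpTag, FlatSbpParallelSketch, kMaxAxes, Matches) are hypothetical.

#include <array>
#include <cstdint>

namespace {

// Tagged union standing in for FLAT_MSG_ONEOF(parallel_type, ...).
enum class SbpTag : int32_t { kSplit, kBroadcast, kPartialSum };

struct FlatSbpParallelSketch {
  SbpTag tag;
  int64_t split_axis;  // meaningful only when tag == SbpTag::kSplit
};

constexpr int kMaxAxes = 20;  // stands in for SHAPE_MAX_AXIS_SIZE

// Fixed-capacity message with a used-size prefix, like FlatParallelDistribution.
struct FlatParallelDistributionSketch {
  uint64_t symbol_id;
  uint64_t size;  // number of valid entries in sbp_parallel
  std::array<FlatSbpParallelSketch, kMaxAxes> sbp_parallel;
};

// Mirrors the spirit of FlatParallelDistribution::Check: two ranks agree only
// if every used entry matches.
bool Matches(const FlatParallelDistributionSketch& a, const FlatParallelDistributionSketch& b) {
  if (a.symbol_id != b.symbol_id || a.size != b.size) { return false; }
  for (uint64_t i = 0; i < a.size; ++i) {
    if (a.sbp_parallel[i].tag != b.sbp_parallel[i].tag) { return false; }
    if (a.sbp_parallel[i].tag == SbpTag::kSplit
        && a.sbp_parallel[i].split_axis != b.sbp_parallel[i].split_axis) {
      return false;
    }
  }
  return true;
}

}  // namespace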
+*/ +#ifndef ONEFLOW_CORE_FRAMEWORK_SYNC_SYMBOL_PARALLEL_DISTRIBUTION_H_ +#define ONEFLOW_CORE_FRAMEWORK_SYNC_SYMBOL_PARALLEL_DISTRIBUTION_H_ + +#include "oneflow/core/common/maybe.h" +#include "oneflow/core/common/symbol.h" +#include "oneflow/core/framework/rpc_util.h" +#include "oneflow/core/framework/rpc_token.h" + +namespace oneflow { + +namespace cfg { + +class ParallelDistribution; + +} + +Maybe SyncSymbolParallelDistribution(uint64_t symbol_id, Symbol); + +} // namespace oneflow + +#endif // ONEFLOW_CORE_FRAMEWORK_SYNC_SYMBOL_PARALLEL_DISTRIBUTION_H_ diff --git a/oneflow/core/framework/synced_symbol_map.cpp b/oneflow/core/framework/synced_symbol_map.cpp new file mode 100644 index 00000000000..f53598ff8ad --- /dev/null +++ b/oneflow/core/framework/synced_symbol_map.cpp @@ -0,0 +1,25 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/core/framework/synced_symbol_map.h" + +namespace oneflow { + +uint64_t GetAutoIncrementalSymbolId() { + static thread_local uint64_t id = 4096; + return id++; +} + +} // namespace oneflow diff --git a/oneflow/core/framework/synced_symbol_map.h b/oneflow/core/framework/synced_symbol_map.h new file mode 100644 index 00000000000..0cf6f0fa155 --- /dev/null +++ b/oneflow/core/framework/synced_symbol_map.h @@ -0,0 +1,74 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ +#ifndef ONEFLOW_CORE_FRAMEWORK_SYNCED_SYMBOL_MAP_H_ +#define ONEFLOW_CORE_FRAMEWORK_SYNCED_SYMBOL_MAP_H_ + +#include +#include "oneflow/core/common/maybe.h" +#include "oneflow/core/common/container_util.h" +#include "oneflow/core/common/symbol.h" +#include "oneflow/core/common/type_traits.h" +#include "oneflow/core/job/rank_group_scope.h" + +namespace oneflow { + +uint64_t GetAutoIncrementalSymbolId(); + +template +struct SyncedSymbolMap final { + template + static Maybe FindOrSync(Symbol symbol, const SyncT& Sync) { + auto* map = JUST(MutThreadLocalSymbol2SyncedSymbolId()); + const auto& iter = map->find(symbol); + if (iter != map->end()) { return iter->second; } + uint64_t symbol_id = GetAutoIncrementalSymbolId(); + JUST(Sync(symbol_id, symbol)); + JUST(Emplace(symbol_id, symbol)); + return symbol_id; + } + + static Maybe> Symbol4SyncedSymbolId(uint64_t synced_symbol_id) { + auto* map = JUST(MutThreadLocalSyncedSymbolId2Symbol()); + return JUST(MapAt(*map, synced_symbol_id)); + } + + private: + static Maybe Emplace(uint64_t synced_symbol_id, Symbol symbol) { + auto* id2symbol = JUST(MutThreadLocalSyncedSymbolId2Symbol()); + CHECK_OR_RETURN(id2symbol->emplace(synced_symbol_id, symbol).second); + auto* symbol2id = JUST(MutThreadLocalSymbol2SyncedSymbolId()); + CHECK_OR_RETURN(symbol2id->emplace(symbol, synced_symbol_id).second); + return Maybe::Ok(); + } + + static Maybe>*> MutThreadLocalSyncedSymbolId2Symbol() { + static thread_local auto* map = + new std::unordered_map, std::unordered_map>>(); + const auto& rank_group = JUST(RankGroupScope::CurrentRankGroup()); + return &(*map)[rank_group]; + } + + static Maybe, uint64_t>*> MutThreadLocalSymbol2SyncedSymbolId() { + static thread_local auto* map = + new std::unordered_map, std::unordered_map, uint64_t>>(); + const auto& rank_group = JUST(RankGroupScope::CurrentRankGroup()); + return &(*map)[rank_group]; + } +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_FRAMEWORK_SYNCED_SYMBOL_MAP_H_ diff --git a/oneflow/core/framework/tensor.cpp b/oneflow/core/framework/tensor.cpp index 5024dbde7c8..2bcb9024f05 100644 --- a/oneflow/core/framework/tensor.cpp +++ b/oneflow/core/framework/tensor.cpp @@ -23,6 +23,7 @@ limitations under the License. #include "oneflow/core/framework/tensor_tuple.h" #include "oneflow/core/autograd/autograd_engine.h" #include "oneflow/core/framework/op_interpreter/eager_mirrored_op_interpreter.h" +#include "oneflow/core/framework/tensor_rpc_util.h" #include "oneflow/core/functional/functional.h" namespace oneflow { diff --git a/oneflow/core/framework/tensor.h b/oneflow/core/framework/tensor.h index 557922e1a61..3c84e8d4402 100644 --- a/oneflow/core/framework/tensor.h +++ b/oneflow/core/framework/tensor.h @@ -22,6 +22,7 @@ limitations under the License. 
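SyncedSymbolMap keys both of its thread-local maps by the current rank group, so a given symbol is pushed over the wire at most once per (thread, rank group); afterwards only its auto-incremented uint64 id (starting at 4096) is exchanged, and the receiving side maps that id back to a local Symbol. A usage sketch under the FindOrSync / Symbol4SyncedSymbolId API above; the two wrapper function names are hypothetical, while ParallelDesc and SyncSymbolParallelDesc come from this patch.

#include "oneflow/core/framework/synced_symbol_map.h"
#include "oneflow/core/framework/sync_symbol_parallel_desc.h"
#include "oneflow/core/job/parallel_desc.h"

namespace oneflow {

// First call per thread/rank-group pays for SyncSymbolParallelDesc; later calls
// just look the cached id up.
Maybe<uint64_t> GetWireIdForParallelDesc(Symbol<ParallelDesc> parallel_desc) {
  return SyncedSymbolMap<ParallelDesc>::FindOrSync(parallel_desc, &SyncSymbolParallelDesc);
}

// The peer that received the id resolves it back to its own Symbol<ParallelDesc>.
Maybe<Symbol<ParallelDesc>> ResolveWireId(uint64_t synced_symbol_id) {
  return SyncedSymbolMap<ParallelDesc>::Symbol4SyncedSymbolId(synced_symbol_id);
}

}  // namespace oneflow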
#include "oneflow/core/common/shape.h" #include "oneflow/core/memory/memory_case.pb.h" #include "oneflow/core/framework/tensor_impl.h" +#include "oneflow/core/framework/rpc_token.h" #include "oneflow/core/common/error.h" namespace oneflow { @@ -48,6 +49,7 @@ class Tensor { virtual const std::shared_ptr& shape() const = 0; virtual DataType dtype() const = 0; + virtual Maybe rpc_token() const = 0; virtual Maybe> parallel_distribution() const = 0; virtual Maybe> parallel_desc() const = 0; virtual Maybe> device() const = 0; @@ -71,8 +73,8 @@ class Tensor { virtual Maybe storage_offset() const { OF_UNIMPLEMENTED(); } // Getters/Setters valid only for EagerConsistentTensor - virtual Maybe> consumer_parallel_distribution_constraint() - const { + virtual Maybe>&> + consumer_parallel_distribution_constraint() const { OF_UNIMPLEMENTED(); } virtual Maybe cur_rank_phy_tensor() const { OF_UNIMPLEMENTED(); } @@ -178,10 +180,11 @@ class Parameter final : public TensorIf { Maybe stride() const override { return tensor_->stride(); } Maybe storage_offset() const override { return tensor_->storage_offset(); } - Maybe> consumer_parallel_distribution_constraint() - const override { + Maybe>&> + consumer_parallel_distribution_constraint() const override { return tensor_->consumer_parallel_distribution_constraint(); } + Maybe rpc_token() const override { return tensor_->rpc_token(); } Maybe cur_rank_phy_tensor() const override { return tensor_->cur_rank_phy_tensor(); } @@ -245,6 +248,7 @@ class MirroredTensor final : public TensorIf, // Getters const std::shared_ptr& shape() const override { return impl_->shape(); } DataType dtype() const override { return impl_->dtype(); } + Maybe rpc_token() const override { OF_UNIMPLEMENTED(); } Maybe> parallel_distribution() const override { OF_UNIMPLEMENTED(); } @@ -328,6 +332,7 @@ class ConsistentTensor final : public TensorIf { // Getters const std::shared_ptr& shape() const override { return impl_->shape(); } DataType dtype() const override { return impl_->dtype(); } + Maybe rpc_token() const override { return impl_->rpc_token(); } Maybe> parallel_distribution() const override { return impl_->parallel_distribution(); } @@ -336,8 +341,8 @@ class ConsistentTensor final : public TensorIf { Maybe*> mut_device() override { OF_UNIMPLEMENTED(); } bool is_lazy() const override { return impl_->is_lazy(); } bool is_consistent() const override { return true; } - Maybe> consumer_parallel_distribution_constraint() - const override { + Maybe>&> + consumer_parallel_distribution_constraint() const override { return impl_->consumer_parallel_distribution_constraint(); } Maybe cur_rank_phy_tensor() const override { diff --git a/oneflow/core/framework/tensor_impl.h b/oneflow/core/framework/tensor_impl.h index 1a713e7aaa9..006efee4a83 100644 --- a/oneflow/core/framework/tensor_impl.h +++ b/oneflow/core/framework/tensor_impl.h @@ -19,11 +19,13 @@ limitations under the License. 
#include "oneflow/core/common/util.h" #include "oneflow/core/common/data_type.h" +#include "oneflow/core/common/optional.h" #include "oneflow/core/job/placement.cfg.h" #include "oneflow/core/framework/object.h" #include "oneflow/core/framework/tensor_storage.h" #include "oneflow/core/framework/tensor_desc.h" #include "oneflow/core/framework/tensor_meta.h" +#include "oneflow/core/framework/rpc_token.h" #include "oneflow/core/autograd/autograd_meta.h" #include "oneflow/core/common/symbol.h" @@ -134,7 +136,8 @@ class ConsistentTensorImpl : public TensorImpl { return tensor_meta_->parallel_distribution(); } Symbol parallel_desc() const { return tensor_meta_->parallel_desc(); } - Symbol consumer_parallel_distribution_constraint() const { + const Optional>& consumer_parallel_distribution_constraint() + const { return consumer_parallel_distribution_constraint_; } virtual Maybe cur_rank_phy_tensor() const { OF_UNIMPLEMENTED(); } @@ -155,14 +158,24 @@ class ConsistentTensorImpl : public TensorImpl { return nullptr; } + const Maybe rpc_token() const { return rpc_token_; } + + Maybe set_rpc_token(const RpcToken& rpc_token) { + CHECK_OR_RETURN(!rpc_token_.IsOk()) << "rpc_token_ is initiliazed"; + rpc_token_ = rpc_token; + return Maybe::Ok(); + } + protected: ConsistentTensorImpl(Symbol tensor_meta, bool requires_grad, bool is_leaf) : TensorImpl(requires_grad, is_leaf), tensor_meta_(tensor_meta), - consumer_parallel_distribution_constraint_() {} + consumer_parallel_distribution_constraint_(), + rpc_token_(Error::ValueError("invalid rpc token")) {} Symbol tensor_meta_; - Symbol consumer_parallel_distribution_constraint_; + Optional> consumer_parallel_distribution_constraint_; + Maybe rpc_token_; }; class LazyMirroredTensorImpl final : public MirroredTensorImpl { diff --git a/oneflow/core/framework/tensor_rpc_util.cpp b/oneflow/core/framework/tensor_rpc_util.cpp new file mode 100644 index 00000000000..a0e3b0f4c92 --- /dev/null +++ b/oneflow/core/framework/tensor_rpc_util.cpp @@ -0,0 +1,149 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ +#include +#include "oneflow/core/framework/tensor_rpc_util.h" +#include "oneflow/core/framework/sync_symbol_consistent_tensor_meta.h" +#include "oneflow/core/framework/sync_symbol_parallel_distribution.h" +#include "oneflow/core/framework/synced_symbol_map.h" +#include "oneflow/core/framework/rank_group_rpc_util.h" +#include "oneflow/core/framework/tensor.h" +#include "oneflow/core/common/flat_shape.h" +#include "oneflow/core/common/shape_vec.h" +#include "oneflow/core/object_msg/flat_msg.h" +#include "oneflow/core/job/rank_group.h" +#include "oneflow/core/job/rank_group_scope.h" + +namespace oneflow { + +// clang-format off +FLAT_MSG_BEGIN(FlatTensorConsistency); + OF_PUBLIC static Maybe New() { + const auto& consistency = std::make_shared(); + consistency->clear(); + return consistency; + } + OF_PUBLIC static Maybe New( + Symbol tensor_meta, + const Optional> consumer_parallel_distribution_constraint, + const RpcToken& tensor_rpc_token) { + const auto& consistency = std::make_shared(); + consistency->clear(); + JUST(consistency->Init(tensor_meta, consumer_parallel_distribution_constraint, tensor_rpc_token)); + return consistency; + } + + OF_PUBLIC Maybe Check(Symbol tensor_meta, + const Optional> consumer_parallel_distribution_constraint, + const RpcToken& tensor_rpc_token) { + const auto& this_synced_tensor_meta = + JUST(SyncedSymbolMap::Symbol4SyncedSymbolId( + this->synced_tensor_meta_symbol_id())); + CHECK_OR_RETURN(this_synced_tensor_meta == tensor_meta); + CHECK_EQ_OR_RETURN(consumer_parallel_distribution_constraint.has_value(), + this->has_consumer_parallel_distribution_constraint_symbol_id()); + if (this->has_consumer_parallel_distribution_constraint_symbol_id()) { + const auto& that_rank_constaint = + JUST(SyncedSymbolMap::Symbol4SyncedSymbolId( + this->consumer_parallel_distribution_constraint_symbol_id())); + const auto& this_rank_constaint = JUST(consumer_parallel_distribution_constraint.value()); + CHECK_OR_RETURN(this_rank_constaint == that_rank_constaint); + } + CHECK_EQ_OR_RETURN(this->tensor_rpc_token(), tensor_rpc_token); + return Maybe::Ok(); + } + + OF_PRIVATE Maybe Init(Symbol tensor_meta, + const Optional> consumer_parallel_distribution_constraint, + const RpcToken& tensor_rpc_token) { + this->set_synced_tensor_meta_symbol_id(JUST(SyncedSymbolMap::FindOrSync( + tensor_meta, &SyncSymbolConsistentTensorMeta))); + if (consumer_parallel_distribution_constraint.has_value()) { + const auto& this_rank_constaint = JUST(consumer_parallel_distribution_constraint.value()); + this->set_consumer_parallel_distribution_constraint_symbol_id( + JUST(SyncedSymbolMap::FindOrSync( + this_rank_constaint, &SyncSymbolParallelDistribution))); + } else { + this->clear_consumer_parallel_distribution_constraint_symbol_id(); + } + this->set_tensor_rpc_token(static_cast(tensor_rpc_token)); + return Maybe::Ok(); + } + + FLAT_MSG_DEFINE_OPTIONAL(uint64_t, synced_tensor_meta_symbol_id); + FLAT_MSG_DEFINE_OPTIONAL(uint64_t, consumer_parallel_distribution_constraint_symbol_id); + FLAT_MSG_DEFINE_OPTIONAL(uint64_t, tensor_rpc_token); +FLAT_MSG_END(FlatTensorConsistency); +// clang-format off + +CheckConsistencyAsyncRpcCtx::~CheckConsistencyAsyncRpcCtx() {} + +Maybe CheckConsistencyAsyncRpcCtx::MakeDataBufferAndCallback( + int64_t rank, void** buffer, std::size_t* size, std::function* Callback) { + const auto& flat_tensor_consistency = JUST(FlatTensorConsistency::New()); + *buffer = flat_tensor_consistency.get(); + *size = sizeof(FlatTensorConsistency); + *Callback = [flat_tensor_consistency]() 
{}; + flat_tensor_consistency_ = flat_tensor_consistency; + return Maybe::Ok(); +} + +Maybe CheckConsistencyAsyncRpcCtx::Check() const { + JUST(flat_tensor_consistency_->Check(tensor_meta_, consumer_parallel_distribution_constraint_, tensor_rpc_token_)); + return Maybe::Ok(); +} + +namespace { + +Maybe SendTensorMetaToNextRankInRing(const one::Tensor& tensor, Symbol rank_group, + const RpcToken& rpc_token) { + const auto& tensor_meta = JUST(tensor.consistent_tensor_meta()); + const auto& constaint = JUST(tensor.consumer_parallel_distribution_constraint()); + const RpcToken& tensor_rpc_token = JUST(tensor.rpc_token()); + NaiveAsyncRpcCtx ctx( + rpc_token, [&](void** buffer, std::size_t* size, std::function* Callback) -> Maybe { + const auto& tensor_consistency = + JUST(FlatTensorConsistency::New(tensor_meta, constaint, tensor_rpc_token)); + *buffer = tensor_consistency.get(); + *size = sizeof(FlatTensorConsistency); + *Callback = [tensor_consistency] {}; + return Maybe::Ok(); + }); + JUST(RpcUtil::SendToNextRankInRing(rank_group, rpc_token, &ctx)); + return Maybe::Ok(); +} + +Maybe ReceiveTensorMetaFromPrevRankInRing(const one::Tensor& tensor, + Symbol rank_group, + const RpcToken& rpc_token) { + const auto& tensor_meta = JUST(tensor.consistent_tensor_meta()); + const auto& constaint = JUST(tensor.consumer_parallel_distribution_constraint()); + const RpcToken& tensor_rpc_token = JUST(tensor.rpc_token()); + const auto& ctx = std::make_shared(rpc_token, tensor_meta, constaint, tensor_rpc_token); + JUST(RpcUtil::ReceiveFromPrevRankInRing(rank_group, rpc_token, ctx.get())); + return ctx; +} + +} // namespace + +Maybe LaunchTensorMetaConsistencyCheck(const one::Tensor& tensor) { + const auto& rank_group = JUST(RankGroupScope::CurrentRankGroup()); + const auto& rpc_token = + JUST(RpcToken::AcquireCtrlRpcToken(kRankGroupRpcCmdCheckTensorConsistency)); + JUST(SendTensorMetaToNextRankInRing(tensor, rank_group, rpc_token)); + return ReceiveTensorMetaFromPrevRankInRing(tensor, rank_group, rpc_token); +} + +} // namespace oneflow diff --git a/oneflow/core/framework/tensor_rpc_util.h b/oneflow/core/framework/tensor_rpc_util.h new file mode 100644 index 00000000000..cced9982585 --- /dev/null +++ b/oneflow/core/framework/tensor_rpc_util.h @@ -0,0 +1,56 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
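LaunchTensorMetaConsistencyCheck only starts the ring exchange; the returned CheckConsistencyAsyncRpcCtx has to be kept alive, waited on, and Check()ed by the caller once it has finished whatever work it can overlap with the transfer. A sketch of that call sequence under the API declared in tensor_rpc_util.h; the caller function name is hypothetical.

#include "oneflow/core/framework/rpc_util.h"
#include "oneflow/core/framework/tensor_rpc_util.h"

namespace oneflow {

// Hypothetical caller: verify that every rank in the current rank group sees
// the same meta, constraint, and rpc token for `tensor`.
Maybe<void> CheckTensorMetaAcrossRanks(const one::Tensor& tensor) {
  // Posts the send to the next rank and the receive from the previous rank.
  const auto& ctx = JUST(LaunchTensorMetaConsistencyCheck(tensor));
  // ... other per-rank work can be overlapped here while the transfer is in flight ...
  JUST(RpcUtil::WaitUntilDoneOrTimeout(*ctx, RpcUtil::TimeoutSeconds()));
  // Compares the received FlatTensorConsistency against local state; any
  // mismatch becomes a returned error.
  JUST(ctx->Check());
  return Maybe<void>::Ok();
}

}  // namespace oneflow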
+*/ +#ifndef ONEFLOW_CORE_FRAMEWORK_TENSOR_RPC_UTIL_H_ +#define ONEFLOW_CORE_FRAMEWORK_TENSOR_RPC_UTIL_H_ + +#include "oneflow/core/framework/rpc_util.h" +#include "oneflow/core/framework/tensor.h" +#include "oneflow/core/common/optional.h" + +namespace oneflow { + +class FlatTensorConsistency; + +class CheckConsistencyAsyncRpcCtx : public AsyncRpcCtx { + public: + CheckConsistencyAsyncRpcCtx( + const RpcToken& rpc_token, Symbol tensor_meta, + const Optional>& consumer_parallel_distribution_constraint, + const RpcToken& tensor_rpc_token) + : AsyncRpcCtx(rpc_token), + tensor_meta_(tensor_meta), + consumer_parallel_distribution_constraint_(consumer_parallel_distribution_constraint), + tensor_rpc_token_(tensor_rpc_token) {} + + ~CheckConsistencyAsyncRpcCtx() override; + + Maybe MakeDataBufferAndCallback(int64_t rank, void** buffer, std::size_t* size, + std::function* Callback) override; + + Maybe Check() const; + + private: + Symbol tensor_meta_; + Optional> consumer_parallel_distribution_constraint_; + RpcToken tensor_rpc_token_; + std::shared_ptr flat_tensor_consistency_; +}; + +Maybe LaunchTensorMetaConsistencyCheck(const one::Tensor& tensor); + +} // namespace oneflow + +#endif // ONEFLOW_CORE_FRAMEWORK_TENSOR_RPC_UTIL_H_ diff --git a/oneflow/core/job/id_manager_test.cpp b/oneflow/core/job/id_manager_test.cpp index 91c26ba16eb..0669f721475 100644 --- a/oneflow/core/job/id_manager_test.cpp +++ b/oneflow/core/job/id_manager_test.cpp @@ -47,14 +47,17 @@ Resource GetResource() { void New() { Global::New(GetEnvProto()); - Global::New()->set_value(1); + Global::New(); + Global::Get()->mutable_ctrl_addr()->Add(); + Global::Get()->set_rank(0); + Global::Get()->set_node_size(1); Global::New(GetResource(), GlobalProcessCtx::NumOfProcessPerNode()); Global::New(); } void Delete() { Global::Delete(); - Global::Delete(); + Global::Delete(); Global::Delete(); Global::Delete(); } diff --git a/oneflow/core/job/parallel_desc.cpp b/oneflow/core/job/parallel_desc.cpp index 81a60f16205..3d096a61a66 100644 --- a/oneflow/core/job/parallel_desc.cpp +++ b/oneflow/core/job/parallel_desc.cpp @@ -111,6 +111,7 @@ Maybe ParallelDesc::MaybeInit(const ParallelConf& user_conf) { GlobalProcessCtx::NumOfProcessPerNode())); } } + containing_current_rank_ = machine_id2sorted_dev_phy_ids_->count(GlobalProcessCtx::Rank()) > 0; ClearUp(); JUST(SanityCheck()); return Maybe::Ok(); diff --git a/oneflow/core/job/parallel_desc.h b/oneflow/core/job/parallel_desc.h index 221ea8ec95a..363758ea77b 100644 --- a/oneflow/core/job/parallel_desc.h +++ b/oneflow/core/job/parallel_desc.h @@ -56,6 +56,7 @@ class ParallelDesc final { // Getters const Maybe& symbol_id() const { return symbol_id_; } + bool containing_current_rank() const { return containing_current_rank_; } DeviceType device_type() const { return device_type_; } const std::string& device_tag() const { return parallel_conf_.device_tag(); } std::shared_ptr>>> @@ -129,6 +130,8 @@ class ParallelDesc final { // TODO(lixinqi): merge cfg_parallel_conf_ and parallel_conf_ after cfg::ParallelConf taken as the // constructor argument std::shared_ptr cfg_parallel_conf_; + // cached result of ContainingMachineId(GlobalProcessCtx::Rank()) for performace optimization. 
+ bool containing_current_rank_; }; Maybe> GetDevice4CurrentProcessCtx(Symbol parallel_desc, diff --git a/oneflow/core/job/parallel_desc_test.cpp b/oneflow/core/job/parallel_desc_test.cpp index fdfc6c2b9ce..975581b43f7 100644 --- a/oneflow/core/job/parallel_desc_test.cpp +++ b/oneflow/core/job/parallel_desc_test.cpp @@ -25,29 +25,31 @@ namespace test { namespace { -void InitNumProcessPerNode() { - Global::New(); - Global::Get()->set_value(1); -} - -void DestroyNumProcessPerNode() { Global::Delete(); } +struct GlobaProcessCtxScope final { + GlobaProcessCtxScope(int64_t node_size, int64_t world_size) { + Global::New(); + auto* ctx = Global::Get(); + for (int i = 0; i < world_size; ++i) { ctx->mutable_ctrl_addr()->Add(); } + ctx->set_rank(0); + ctx->set_node_size(node_size); + } + ~GlobaProcessCtxScope() { Global::Delete(); } +}; } // namespace TEST(ParallelDesc, continuous_1n4d) { - InitNumProcessPerNode(); + GlobaProcessCtxScope scope(1, 4); ParallelConf parallel_conf; parallel_conf.set_device_tag("cpu"); parallel_conf.add_device_name("0:0-3"); ParallelDesc parallel_desc(parallel_conf); ASSERT_EQ(parallel_desc.device_tag(), "cpu"); ASSERT_EQ(parallel_desc.parallel_num(), 4); - DestroyNumProcessPerNode(); } TEST(ParallelDesc, continuous_1n4d_multi_process) { - InitNumProcessPerNode(); - Global::Get()->set_value(4); + GlobaProcessCtxScope scope(1, 4); ParallelConf parallel_conf; parallel_conf.set_device_tag("cpu"); parallel_conf.add_device_name("0:0-3"); @@ -59,12 +61,10 @@ TEST(ParallelDesc, continuous_1n4d_multi_process) { ASSERT_EQ(std::count(machine_ids.begin(), machine_ids.end(), 1), 1); ASSERT_EQ(std::count(machine_ids.begin(), machine_ids.end(), 2), 1); ASSERT_EQ(std::count(machine_ids.begin(), machine_ids.end(), 3), 1); - DestroyNumProcessPerNode(); } TEST(ParallelDesc, continuous_1n4d_multi_process_with_rank) { - InitNumProcessPerNode(); - Global::Get()->set_value(4); + GlobaProcessCtxScope scope(1, 4); ParallelConf parallel_conf; parallel_conf.set_device_tag("cpu"); parallel_conf.add_device_name("@0:0-3"); @@ -74,11 +74,10 @@ TEST(ParallelDesc, continuous_1n4d_multi_process_with_rank) { ASSERT_EQ(parallel_desc.parallel_num(), 4); ASSERT_EQ(machine_ids.size(), 1); ASSERT_EQ(std::count(machine_ids.begin(), machine_ids.end(), 0), 1); - DestroyNumProcessPerNode(); } TEST(ParallelDesc, discrete_1n4d) { - InitNumProcessPerNode(); + GlobaProcessCtxScope scope(1, 4); ParallelConf parallel_conf; parallel_conf.set_device_tag("cpu"); parallel_conf.add_device_name("0:0-1"); @@ -86,11 +85,10 @@ TEST(ParallelDesc, discrete_1n4d) { ParallelDesc parallel_desc(parallel_conf); ASSERT_EQ(parallel_desc.device_tag(), "cpu"); ASSERT_EQ(parallel_desc.parallel_num(), 4); - DestroyNumProcessPerNode(); } TEST(ParallelDesc, continuous_2n8d) { - InitNumProcessPerNode(); + GlobaProcessCtxScope scope(2, 8); ParallelConf parallel_conf; parallel_conf.set_device_tag("cpu"); parallel_conf.add_device_name("0:0-3"); @@ -98,11 +96,10 @@ TEST(ParallelDesc, continuous_2n8d) { ParallelDesc parallel_desc(parallel_conf); ASSERT_EQ(parallel_desc.device_tag(), "cpu"); ASSERT_EQ(parallel_desc.parallel_num(), 8); - DestroyNumProcessPerNode(); } TEST(ParallelDesc, discrete_2n8d) { - InitNumProcessPerNode(); + GlobaProcessCtxScope scope(2, 8); ParallelConf parallel_conf; parallel_conf.set_device_tag("cpu"); parallel_conf.add_device_name("0:0-1"); @@ -112,7 +109,6 @@ TEST(ParallelDesc, discrete_2n8d) { ParallelDesc parallel_desc(parallel_conf); ASSERT_EQ(parallel_desc.device_tag(), "cpu"); ASSERT_EQ(parallel_desc.parallel_num(), 
8); - DestroyNumProcessPerNode(); } } // namespace test diff --git a/oneflow/core/job/rank_group.cpp b/oneflow/core/job/rank_group.cpp new file mode 100644 index 00000000000..f1a104a12ec --- /dev/null +++ b/oneflow/core/job/rank_group.cpp @@ -0,0 +1,94 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include +#include "oneflow/core/job/rank_group.h" +#include "oneflow/core/common/util.h" +#include "oneflow/core/common/container_util.h" +#include "oneflow/core/rpc/include/global_process_ctx.h" + +namespace oneflow { + +/*static*/ Maybe> RankGroup::New(const std::set& ranks) { + static thread_local std::map, Symbol> map; + auto iter = map.find(ranks); + if (iter == map.end()) { + RankGroup rank_group; + JUST(rank_group.Init(ranks)); + iter = map.emplace(ranks, SymbolOf(rank_group)).first; + } + return iter->second; +} + +namespace { + +Maybe> AllWorldRanks() { + const auto& ranks = std::make_shared>(); + for (int i = 0; i < GlobalProcessCtx::WorldSize(); ++i) { ranks->insert(i); } + return ranks; +} + +} // namespace + +/*static*/ Maybe> RankGroup::DefaultRankGroup() { + const auto& all_wold_ranks = JUST(AllWorldRanks()); + const auto& rank_group = JUST(RankGroup::New(*all_wold_ranks)); + return rank_group; +} + +Maybe RankGroup::Init(const std::set& ranks) { + ranks_ = ranks; + // Initialize rank2next_rank_in_ring_ and rank2prev_rank_in_ring_ + { + CHECK_GT_OR_RETURN(ranks.size(), 0); + int64_t last = *(--ranks.end()); + for (int64_t i : ranks) { + CHECK_OR_RETURN(rank2next_rank_in_ring_.emplace(last, i).second); + CHECK_OR_RETURN(rank2prev_rank_in_ring_.emplace(i, last).second); + last = i; + } + } + // Initialize hash_value_ + hash_value_ = 0; + for (int64_t i : ranks) { HashCombine(&hash_value_, i); } + return Maybe::Ok(); +} + +Maybe RankGroup::GetNextRankInRing(int64_t rank) const { + return MapAt(rank2next_rank_in_ring_, rank); +} + +Maybe RankGroup::GetNextRankInRing() const { + return GetNextRankInRing(GlobalProcessCtx::Rank()); +} + +Maybe RankGroup::GetPrevRankInRing(int64_t rank) const { + return MapAt(rank2prev_rank_in_ring_, rank); +} + +Maybe RankGroup::GetPrevRankInRing() const { + return GetPrevRankInRing(GlobalProcessCtx::Rank()); +} + +bool RankGroup::ContainingCurrentRank() const { + return rank2next_rank_in_ring_.count(GlobalProcessCtx::Rank()) > 0; +} + +Maybe RankGroup::ForEachRank(const std::function(int64_t)>& DoEach) const { + for (int64_t i : ranks_) { JUST(DoEach(i)); } + return Maybe::Ok(); +} + +} // namespace oneflow diff --git a/oneflow/core/job/rank_group.h b/oneflow/core/job/rank_group.h new file mode 100644 index 00000000000..1eceadcf812 --- /dev/null +++ b/oneflow/core/job/rank_group.h @@ -0,0 +1,71 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#ifndef ONEFLOW_CORE_JOB_RANK_GROUP_H_ +#define ONEFLOW_CORE_JOB_RANK_GROUP_H_ + +#include +#include +#include +#include +#include +#include +#include "oneflow/core/common/symbol.h" +#include "oneflow/core/common/maybe.h" + +namespace oneflow { + +class RankGroup final { + public: + ~RankGroup() = default; + + static Maybe> New(const std::set& ranks); + static Maybe> DefaultRankGroup(); + + bool operator==(const RankGroup& that) const { return this->ranks_ == that.ranks_; } + bool operator!=(const RankGroup& that) const { return !(*this == that); } + + size_t size() const { return ranks_.size(); } + size_t hash_value() const { return hash_value_; } + Maybe GetNextRankInRing(int64_t rank) const; + Maybe GetNextRankInRing() const; + Maybe GetPrevRankInRing(int64_t rank) const; + Maybe GetPrevRankInRing() const; + bool ContainingCurrentRank() const; + + Maybe ForEachRank(const std::function(int64_t)>&) const; + + private: + RankGroup() = default; + Maybe Init(const std::set& ranks); + + std::set ranks_; + std::unordered_map rank2next_rank_in_ring_; + std::unordered_map rank2prev_rank_in_ring_; + size_t hash_value_; +}; + +} // namespace oneflow + +namespace std { + +template<> +struct hash final { + size_t operator()(const oneflow::RankGroup& rank_group) const { return rank_group.hash_value(); } +}; + +} // namespace std + +#endif // ONEFLOW_CORE_JOB_RANK_GROUP_H_ diff --git a/oneflow/core/job/rank_group_scope.cpp b/oneflow/core/job/rank_group_scope.cpp new file mode 100644 index 00000000000..1155a49fe57 --- /dev/null +++ b/oneflow/core/job/rank_group_scope.cpp @@ -0,0 +1,77 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
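Because std::set iterates in ascending order, RankGroup::Init wires the ring by walking the sorted ranks and linking each one back to the previously visited element, with the largest rank wrapping around to the smallest. A standalone sketch of the same construction (BuildRing is a hypothetical helper, not part of the patch):

#include <cstdint>
#include <set>
#include <unordered_map>

// For the sorted ranks {0, 1, 3, 4} this yields next(4) == 0 and prev(0) == 4,
// matching the rank_group_test expectations below.
void BuildRing(const std::set<int64_t>& ranks,
               std::unordered_map<int64_t, int64_t>* next,
               std::unordered_map<int64_t, int64_t>* prev) {
  int64_t last = *ranks.rbegin();  // largest rank; its successor is the smallest
  for (int64_t rank : ranks) {
    (*next)[last] = rank;
    (*prev)[rank] = last;
    last = rank;
  }
}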
+*/ +#include "oneflow/core/job/rank_group_scope.h" + +namespace oneflow { + +namespace { + +const RankGroupScope** MutThreadLocalRankGroupScope() { + static thread_local const RankGroupScope* scope = nullptr; + return &scope; +} + +} // namespace + +RankGroupScope::RankGroupScope(Symbol rank_group, const RankGroupScope* parent, + const RankGroupScope* root) + : rank_group_(rank_group), parent_(parent), root_(root) { + CHECK_EQ(parent, *MutThreadLocalRankGroupScope()); + *MutThreadLocalRankGroupScope() = this; +} + +Maybe RankGroupScope::SetRootSelf() { + CHECK_ISNULL_OR_RETURN(parent_); + CHECK_ISNULL_OR_RETURN(root_); + root_ = this; + return Maybe::Ok(); +} + +RankGroupScope::~RankGroupScope() { + CHECK_EQ(this, *MutThreadLocalRankGroupScope()); + *MutThreadLocalRankGroupScope() = this->parent_; +} + +/*static*/ Maybe RankGroupScope::MakeInitialRankGroupScope( + Symbol rank_group) { + CHECK_ISNULL_OR_RETURN(*MutThreadLocalRankGroupScope()); + auto* ptr = new RankGroupScope(rank_group, nullptr, nullptr); + JUST(ptr->SetRootSelf()); + return std::shared_ptr(ptr); +} + +/*static*/ Maybe RankGroupScope::MakeNestedRankGroupScope( + Symbol rank_group) { + const auto* parent = *MutThreadLocalRankGroupScope(); + CHECK_NOTNULL_OR_RETURN(parent); + const auto* root = &parent->root(); + auto* ptr = new RankGroupScope(rank_group, parent, root); + return std::shared_ptr(ptr); +} + +/*static*/ Maybe> RankGroupScope::CurrentRankGroup() { + const RankGroupScope* scope = *MutThreadLocalRankGroupScope(); + CHECK_NOTNULL_OR_RETURN(scope); + return scope->rank_group(); +} + +/*static*/ Maybe> RankGroupScope::RootRankGroup() { + const RankGroupScope* scope = *MutThreadLocalRankGroupScope(); + CHECK_NOTNULL_OR_RETURN(scope); + return scope->root().rank_group(); +} + +} // namespace oneflow diff --git a/oneflow/core/job/rank_group_scope.h b/oneflow/core/job/rank_group_scope.h new file mode 100644 index 00000000000..cf5b6e818a6 --- /dev/null +++ b/oneflow/core/job/rank_group_scope.h @@ -0,0 +1,52 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ +#ifndef ONEFLOW_CORE_JOB_RANK_GROUP_SCOPE_H_ +#define ONEFLOW_CORE_JOB_RANK_GROUP_SCOPE_H_ + +#include "oneflow/core/job/rank_group.h" +#include "oneflow/core/common/symbol.h" + +namespace oneflow { + +class RankGroupScope final { + public: + ~RankGroupScope(); + + Symbol rank_group() const { return rank_group_; } + const RankGroupScope& root() const { return *root_; } + + static Maybe MakeNestedRankGroupScope(Symbol rank_group); + + static Maybe MakeInitialRankGroupScope(Symbol rank_group); + + static Maybe> CurrentRankGroup(); + + static Maybe> RootRankGroup(); + + private: + RankGroupScope(Symbol rank_group, const RankGroupScope* parent, + const RankGroupScope* root); + + Maybe SetRootSelf(); + + Symbol rank_group_; + const RankGroupScope* parent_; + const RankGroupScope* root_; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_JOB_RANK_GROUP_SCOPE_H_ diff --git a/oneflow/core/job/rank_group_scope_test.cpp b/oneflow/core/job/rank_group_scope_test.cpp new file mode 100644 index 00000000000..cad1a92de43 --- /dev/null +++ b/oneflow/core/job/rank_group_scope_test.cpp @@ -0,0 +1,64 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include +#include +#include "oneflow/core/common/util.h" +#include "oneflow/core/job/rank_group.h" +#include "oneflow/core/job/rank_group_scope.h" +#include "oneflow/core/control/ctrl_bootstrap.pb.h" + +namespace oneflow { +namespace test { + +TEST(RankGroupScope, initial) { + const auto& rank_group0 = CHECK_JUST(RankGroup::New(std::set{0, 1, 2})); + auto rank_group_scope = CHECK_JUST(RankGroupScope::MakeInitialRankGroupScope(rank_group0)); + int64_t rank = 0; + const auto& rank_group = CHECK_JUST(RankGroupScope::CurrentRankGroup()); + rank = CHECK_JUST(rank_group->GetNextRankInRing(0)); + ASSERT_EQ(rank, 1); + rank_group_scope.reset(); + ASSERT_FALSE(TRY(RankGroupScope::CurrentRankGroup()).IsOk()); +} + +TEST(RankGroupScope, nonconsecutive_rank) { + const auto& rank_group0 = CHECK_JUST(RankGroup::New(std::set{0, 1, 2})); + auto rank_group_scope0 = CHECK_JUST(RankGroupScope::MakeInitialRankGroupScope(rank_group0)); + int64_t rank = 0; + const auto& rank_group = CHECK_JUST(RankGroupScope::CurrentRankGroup()); + rank = CHECK_JUST(rank_group->GetNextRankInRing(0)); + ASSERT_EQ(rank, 1); + rank = CHECK_JUST(rank_group->GetNextRankInRing(2)); + ASSERT_EQ(rank, 0); + { + const auto& rank_group1 = CHECK_JUST(RankGroup::New(std::set{0, 1})); + auto rank_group_scope1 = CHECK_JUST(RankGroupScope::MakeNestedRankGroupScope(rank_group1)); + { + const auto& rank_group2 = CHECK_JUST(RankGroup::New(std::set{0})); + auto rank_group_scope2 = CHECK_JUST(RankGroupScope::MakeNestedRankGroupScope(rank_group2)); + const auto& current_rank_group = CHECK_JUST(RankGroupScope::CurrentRankGroup()); + ASSERT_TRUE(rank_group2 == current_rank_group); + const auto& root_rank_group = CHECK_JUST(RankGroupScope::RootRankGroup()); + ASSERT_TRUE(rank_group == root_rank_group); + rank_group_scope2.reset(); + } + rank_group_scope1.reset(); + } + 
rank_group_scope0.reset(); +} + +} // namespace test +} // namespace oneflow diff --git a/oneflow/core/job/rank_group_test.cpp b/oneflow/core/job/rank_group_test.cpp new file mode 100644 index 00000000000..1110f8d0c2c --- /dev/null +++ b/oneflow/core/job/rank_group_test.cpp @@ -0,0 +1,62 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include +#include +#include "oneflow/core/common/util.h" +#include "oneflow/core/job/rank_group.h" +#include "oneflow/core/control/ctrl_bootstrap.pb.h" + +namespace oneflow { +namespace test { + +TEST(RankGroup, two_rank) { + const auto& rank_group = CHECK_JUST(RankGroup::New(std::set{0, 1})); + int64_t rank = 0; + rank = CHECK_JUST(rank_group->GetNextRankInRing(0)); + ASSERT_EQ(rank, 1); + rank = CHECK_JUST(rank_group->GetNextRankInRing(1)); + ASSERT_EQ(rank, 0); + rank = CHECK_JUST(rank_group->GetPrevRankInRing(0)); + ASSERT_EQ(rank, 1); + rank = CHECK_JUST(rank_group->GetPrevRankInRing(1)); + ASSERT_EQ(rank, 0); +} + +TEST(RankGroup, nonconsecutive_rank) { + const auto& rank_group = CHECK_JUST(RankGroup::New(std::set{0, 1, 3, 4})); + int64_t rank = 0; + rank = CHECK_JUST(rank_group->GetNextRankInRing(0)); + ASSERT_EQ(rank, 1); + rank = CHECK_JUST(rank_group->GetNextRankInRing(1)); + ASSERT_EQ(rank, 3); + rank = CHECK_JUST(rank_group->GetNextRankInRing(3)); + ASSERT_EQ(rank, 4); + rank = CHECK_JUST(rank_group->GetNextRankInRing(4)); + ASSERT_EQ(rank, 0); + bool is_ok = TRY(rank_group->GetNextRankInRing(2)).IsOk(); + ASSERT_FALSE(is_ok); + rank = CHECK_JUST(rank_group->GetPrevRankInRing(1)); + ASSERT_EQ(rank, 0); + rank = CHECK_JUST(rank_group->GetPrevRankInRing(3)); + ASSERT_EQ(rank, 1); + rank = CHECK_JUST(rank_group->GetPrevRankInRing(4)); + ASSERT_EQ(rank, 3); + rank = CHECK_JUST(rank_group->GetPrevRankInRing(0)); + ASSERT_EQ(rank, 4); +} + +} // namespace test +} // namespace oneflow diff --git a/oneflow/core/rpc/lib/global_process_ctx.cpp b/oneflow/core/rpc/lib/global_process_ctx.cpp index ad4eca86a67..9d5d670bb92 100644 --- a/oneflow/core/rpc/lib/global_process_ctx.cpp +++ b/oneflow/core/rpc/lib/global_process_ctx.cpp @@ -49,9 +49,6 @@ int64_t GlobalProcessCtx::ThisNodeId() { } int64_t GlobalProcessCtx::NumOfProcessPerNode() { - if (Global::Get() != nullptr) { - return int64_t(Global::Get()->value()); - } CHECK_NOTNULL(Global::Get()); CHECK_EQ(WorldSize() % NodeSize(), 0); return int64_t(WorldSize() / NodeSize()); diff --git a/oneflow/core/thread/consistent_unique_id.cpp b/oneflow/core/thread/consistent_unique_id.cpp new file mode 100644 index 00000000000..6fb5ce3364e --- /dev/null +++ b/oneflow/core/thread/consistent_unique_id.cpp @@ -0,0 +1,78 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
diff --git a/oneflow/core/thread/consistent_unique_id.cpp b/oneflow/core/thread/consistent_unique_id.cpp
new file mode 100644
index 00000000000..6fb5ce3364e
--- /dev/null
+++ b/oneflow/core/thread/consistent_unique_id.cpp
@@ -0,0 +1,78 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/core/thread/consistent_unique_id.h"
+#include "oneflow/core/common/util.h"
+#include "oneflow/core/framework/rpc_util.h"
+#include "oneflow/core/common/container_util.h"
+
+namespace oneflow {
+
+namespace {
+
+class ConsistentUniqueIdStorage final {
+ public:
+  ConsistentUniqueIdStorage() = default;
+  ~ConsistentUniqueIdStorage() = default;
+
+  static ConsistentUniqueIdStorage* Singleton() {
+    static auto* storage = new ConsistentUniqueIdStorage();
+    return storage;
+  }
+
+  Maybe<void> Emplace(int64_t id, const std::string& debug_string) {
+    std::unique_lock<std::mutex> lock(mutex_);
+    CHECK_LE_OR_RETURN(id2debug_string_.size(), RpcToken::MaxNumberOfThreadConsistentUId());
+    for (const auto& pair : id2debug_string_) { CHECK_NE_OR_RETURN(debug_string, pair.second); }
+    CHECK_OR_RETURN(id2debug_string_.emplace(id, debug_string).second);
+    return Maybe<void>::Ok();
+  }
+
+  Maybe<const std::string&> DebugString(int64_t id) const {
+    std::unique_lock<std::mutex> lock(mutex_);
+    return MapAt(id2debug_string_, id);
+  }
+
+ private:
+  mutable std::mutex mutex_;
+  HashMap<int64_t, std::string> id2debug_string_;
+};
+
+std::unique_ptr<int64_t>* MutThreadLocalConsistentUniqueId() {
+  static thread_local std::unique_ptr<int64_t> consistent_uid;
+  return &consistent_uid;
+}
+
+}  // namespace
+
+Maybe<void> SetThisThreadConsistentUniqueId(int64_t id, const std::string& debug_string) {
+  JUST(ConsistentUniqueIdStorage::Singleton()->Emplace(id, debug_string));
+  auto* ptr = MutThreadLocalConsistentUniqueId();
+  CHECK_ISNULL_OR_RETURN(ptr->get());
+  ptr->reset(new int64_t(id));
+  return Maybe<void>::Ok();
+}
+
+Maybe<int64_t> GetThisThreadConsistentUniqueId() {
+  auto* ptr = MutThreadLocalConsistentUniqueId();
+  CHECK_NOTNULL_OR_RETURN(ptr->get());
+  return **ptr;
+}
+
+Maybe<const std::string&> GetThreadConsistentUniqueIdDebugString(int64_t thread_consistent_unique_id) {
+  return ConsistentUniqueIdStorage::Singleton()->DebugString(thread_consistent_unique_id);
+}
+
+}  // namespace oneflow
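A minimal usage sketch of the API defined above (illustrative only; the id and debug string are made up, and the real registration happens in the thread and rpc code added elsewhere in this patch):

#include "oneflow/core/common/util.h"
#include "oneflow/core/thread/consistent_unique_id.h"

namespace oneflow {

// Each thread registers exactly once; reusing the same id or the same debug
// string from another thread is rejected by the storage above.
void ExampleRegisterThisThread() {
  CHECK_JUST(SetThisThreadConsistentUniqueId(/*thread_consistent_uid=*/1, "example-worker"));
  // Later code running on this thread can recover the id without passing it around.
  CHECK_EQ(CHECK_JUST(GetThisThreadConsistentUniqueId()), 1);
}

}  // namespace oneflow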
diff --git a/oneflow/core/thread/consistent_unique_id.h b/oneflow/core/thread/consistent_unique_id.h
new file mode 100644
index 00000000000..7e4bc65582e
--- /dev/null
+++ b/oneflow/core/thread/consistent_unique_id.h
@@ -0,0 +1,31 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#ifndef ONEFLOW_CORE_THREAD_CONSISTENT_UNIQUE_ID_H_
+#define ONEFLOW_CORE_THREAD_CONSISTENT_UNIQUE_ID_H_
+
+#include <string>
+#include "oneflow/core/common/maybe.h"
+
+namespace oneflow {
+
+Maybe<void> SetThisThreadConsistentUniqueId(int64_t thread_consistent_uid,
+                                            const std::string& debug_string);
+Maybe<int64_t> GetThisThreadConsistentUniqueId();
+Maybe<const std::string&> GetThreadConsistentUniqueIdDebugString(int64_t thread_consistent_uid);
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_THREAD_CONSISTENT_UNIQUE_ID_H_
diff --git a/oneflow/core/vm/nop_stream_type_test.cpp b/oneflow/core/vm/nop_stream_type_test.cpp
index e13dd06b20c..181d6b5306b 100644
--- a/oneflow/core/vm/nop_stream_type_test.cpp
+++ b/oneflow/core/vm/nop_stream_type_test.cpp
@@ -35,13 +35,6 @@ namespace test {
 
 namespace {
 
-void InitNumProcessPerNode() {
-  Global<NumProcessPerNode>::New();
-  Global<NumProcessPerNode>::Get()->set_value(1);
-}
-
-void DestroyNumProcessPerNode() { Global<NumProcessPerNode>::Delete(); }
-
 ObjectMsgPtr<VirtualMachine> NaiveNewVirtualMachine(const VmDesc& vm_desc) {
   return ObjectMsgPtr<VirtualMachine>::New(vm_desc);
 }
@@ -87,16 +80,18 @@ void TestNopStreamTypeNoArgument(
   ASSERT_EQ(instruction->mut_instr_msg(), nop_instr_msg.Mutable());
 }
 
-TEST(NopStreamType, no_argument) { TestNopStreamTypeNoArgument(&NaiveNewVirtualMachine); }
+TEST(NopStreamType, no_argument) {
+  TestResourceDescScope scope(1, 1);
+  TestNopStreamTypeNoArgument(&NaiveNewVirtualMachine);
+}
 
 TEST(NopStreamType, cached_allocator_no_argument) {
+  TestResourceDescScope scope(1, 1);
   TestNopStreamTypeNoArgument(CachedAllocatorNewVirtualMachine());
 }
 
 void TestNopStreamTypeOneArgument(
     std::function<ObjectMsgPtr<VirtualMachine>(const VmDesc&)> NewVirtualMachine) {
-  InitNumProcessPerNode();
-  TestResourceDescScope scope(1, 1);
   auto vm_desc = ObjectMsgPtr<VmDesc>::New(TestUtil::NewVmResourceDesc().Get());
   TestUtil::AddStreamDescByInstrNames(vm_desc.Mutable(), {"Nop", "NewObject"});
   auto vm = NewVirtualMachine(vm_desc.Get());
@@ -114,19 +109,20 @@ void TestNopStreamTypeOneArgument(
     vm->Schedule();
     OBJECT_MSG_LIST_FOR_EACH_PTR(vm->mut_thread_ctx_list(), t) { t->TryReceiveAndRun(); }
   }
-  DestroyNumProcessPerNode();
 }
 
 TEST(NopStreamType, one_argument_dispatch) {
+  TestResourceDescScope scope(1, 1);
   TestNopStreamTypeOneArgument(&NaiveNewVirtualMachine);
 }
 
 TEST(NopStreamType, cached_allocator_one_argument_dispatch) {
+  TestResourceDescScope scope(1, 1);
   TestNopStreamTypeOneArgument(CachedAllocatorNewVirtualMachine());
 }
 
 TEST(NopStreamType, one_argument_triger_next_instruction) {
-  InitNumProcessPerNode();
+  TestResourceDescScope scope(1, 1);
   auto vm_desc = ObjectMsgPtr<VmDesc>::New(TestUtil::NewVmResourceDesc().Get());
   TestUtil::AddStreamDescByInstrNames(vm_desc.Mutable(), {"Nop", "NewObject"});
   auto vm = NaiveNewVirtualMachine(vm_desc.Get());
@@ -143,11 +139,10 @@ TEST(NopStreamType, one_argument_triger_next_instruction) {
     vm->Schedule();
     OBJECT_MSG_LIST_FOR_EACH_PTR(vm->mut_thread_ctx_list(), t) { t->TryReceiveAndRun(); }
   }
-  DestroyNumProcessPerNode();
 }
 
 TEST(NopStreamType, one_argument_triger_all_instructions) {
-  InitNumProcessPerNode();
+  TestResourceDescScope scope(1, 1);
  auto vm_desc = ObjectMsgPtr<VmDesc>::New(TestUtil::NewVmResourceDesc().Get());
   TestUtil::AddStreamDescByInstrNames(vm_desc.Mutable(), {"Nop", "NewObject"});
   auto vm = NaiveNewVirtualMachine(vm_desc.Get());
@@ -164,7 +159,6 @@ TEST(NopStreamType, one_argument_triger_all_instructions) {
     vm->Schedule();
     OBJECT_MSG_LIST_FOR_EACH_PTR(vm->mut_thread_ctx_list(), t) { t->TryReceiveAndRun(); }
   }
-  DestroyNumProcessPerNode();
 }
 
 } // namespace
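One note on the pattern above: the paired InitNumProcessPerNode()/DestroyNumProcessPerNode() helpers are replaced by scope objects constructed at the top of each test. In general, a paired Global new/delete can be wrapped in a small guard; the sketch below is an editorial illustration of that idiom, not something the patch adds, and it assumes Global<T>::New() forwards constructor arguments the way the Global<EnvDesc>::New(env_proto) call elsewhere in this patch suggests.

#include <utility>
#include "oneflow/core/common/global.h"

namespace oneflow {
namespace test {

// RAII guard: creates the global on construction, deletes it on destruction.
template<typename T>
struct GlobalGuard final {
  template<typename... Args>
  explicit GlobalGuard(Args&&... args) {
    Global<T>::New(std::forward<Args>(args)...);
  }
  ~GlobalGuard() { Global<T>::Delete(); }
  GlobalGuard(const GlobalGuard&) = delete;
  GlobalGuard& operator=(const GlobalGuard&) = delete;
};

}  // namespace test
}  // namespace oneflow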
diff --git a/oneflow/core/vm/object_instruction_type_test.cpp b/oneflow/core/vm/object_instruction_type_test.cpp
index d557c753ebf..c0fe0e280b1 100644
--- a/oneflow/core/vm/object_instruction_type_test.cpp
+++ b/oneflow/core/vm/object_instruction_type_test.cpp
@@ -36,15 +36,19 @@ namespace test {
 
 namespace {
 
-void InitNumProcessPerNode() {
-  Global<NumProcessPerNode>::New();
-  Global<NumProcessPerNode>::Get()->set_value(1);
-}
-
-void DestroyNumProcessPerNode() { Global<NumProcessPerNode>::Delete(); }
+struct GlobalProcessCtxScope final {
+  GlobalProcessCtxScope(int64_t node_size, int64_t world_size) {
+    Global<ProcessCtx>::New();
+    auto* ctx = Global<ProcessCtx>::Get();
+    for (int i = 0; i < world_size; ++i) { ctx->mutable_ctrl_addr()->Add(); }
+    ctx->set_rank(0);
+    ctx->set_node_size(node_size);
+  }
+  ~GlobalProcessCtxScope() { Global<ProcessCtx>::Delete(); }
+};
 
 TEST(ControlStreamType, new_object) {
-  InitNumProcessPerNode();
+  GlobalProcessCtxScope scope(1, 1);
   auto vm_desc = ObjectMsgPtr<VmDesc>::New(TestUtil::NewVmResourceDesc().Get());
   TestUtil::AddStreamDescByInstrNames(vm_desc.Mutable(), {"NewObject"});
   CachedObjectMsgAllocator allocator(20, 100);
@@ -57,11 +61,10 @@ TEST(ControlStreamType, new_object) {
     vm->Schedule();
     OBJECT_MSG_LIST_FOR_EACH_PTR(vm->mut_thread_ctx_list(), t) { t->TryReceiveAndRun(); }
   }
-  DestroyNumProcessPerNode();
 }
 
 TEST(ControlStreamType, delete_object) {
-  InitNumProcessPerNode();
+  GlobalProcessCtxScope scope(1, 1);
   auto vm_desc = ObjectMsgPtr<VmDesc>::New(TestUtil::NewVmResourceDesc().Get());
   TestUtil::AddStreamDescByInstrNames(vm_desc.Mutable(), {"NewObject"});
   CachedObjectMsgAllocator allocator(20, 100);
@@ -75,7 +78,6 @@ TEST(ControlStreamType, delete_object) {
     vm->Schedule();
     OBJECT_MSG_LIST_FOR_EACH_PTR(vm->mut_thread_ctx_list(), t) { t->TryReceiveAndRun(); }
   }
-  DestroyNumProcessPerNode();
 }
 
 } // namespace
diff --git a/oneflow/core/vm/sequential_instruction_type_test.cpp b/oneflow/core/vm/sequential_instruction_type_test.cpp
index f4e4047c179..683eb404880 100644
--- a/oneflow/core/vm/sequential_instruction_type_test.cpp
+++ b/oneflow/core/vm/sequential_instruction_type_test.cpp
@@ -44,15 +44,11 @@ namespace {
 struct GlobalProcessCtxScope {
   GlobalProcessCtxScope() {
     auto* ctx = Global<ProcessCtx>::New();
+    ctx->mutable_ctrl_addr()->Add();
     ctx->set_rank(0);
     ctx->set_node_size(1);
-    Global<NumProcessPerNode>::New();
-    Global<NumProcessPerNode>::Get()->set_value(1);
-  }
-  ~GlobalProcessCtxScope() {
-    Global<NumProcessPerNode>::Delete();
-    Global<ProcessCtx>::Delete();
   }
+  ~GlobalProcessCtxScope() { Global<ProcessCtx>::Delete(); }
 };
 
 TEST(SequentialInstruction, front_seq_compute) {
diff --git a/oneflow/core/vm/test_util.cpp b/oneflow/core/vm/test_util.cpp
index 851e22f3671..74ed0890d2b 100644
--- a/oneflow/core/vm/test_util.cpp
+++ b/oneflow/core/vm/test_util.cpp
@@ -41,6 +41,10 @@ EnvProto GetEnvProto(int64_t machine_num) {
 
 TestResourceDescScope::TestResourceDescScope(int64_t gpu_device_num, int64_t cpu_device_num,
                                              int64_t machine_num) {
+  Global<ProcessCtx>::New();
+  Global<ProcessCtx>::Get()->mutable_ctrl_addr()->Add();
+  Global<ProcessCtx>::Get()->set_rank(0);
+  Global<ProcessCtx>::Get()->set_node_size(1);
   EnvProto env_proto = GetEnvProto(machine_num);
   Global<EnvDesc>::New(env_proto);
   Resource resource;
@@ -52,8 +56,8 @@ TestResourceDescScope::TestResourceDescScope(int64_t gpu_device_num, int64_t cpu
 
 TestResourceDescScope::~TestResourceDescScope() {
   Global<ResourceDesc, ForSession>::Delete();
-  Global<NumProcessPerNode>::Delete();
   Global<EnvDesc>::Delete();
+  Global<ProcessCtx>::Delete();
 }
 
 ObjectMsgPtr<VmResourceDesc> TestUtil::NewVmResourceDesc(int64_t device_num, int64_t machine_num) {
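The same ProcessCtx setup and teardown now appears in three places (object_instruction_type_test.cpp, sequential_instruction_type_test.cpp and test_util.cpp). A shared helper along the following lines could remove that duplication; this is an editorial sketch of a possible refactor, not something this patch introduces, and the file it would live in is left unspecified.

#include "oneflow/core/common/global.h"
#include "oneflow/core/control/ctrl_bootstrap.pb.h"

namespace oneflow {
namespace test {

// One ctrl address per rank in the world, rank 0 for the local process, exactly
// the setup the three test files above each perform by hand.
struct ProcessCtxScope final {
  ProcessCtxScope(int64_t node_size, int64_t world_size) {
    Global<ProcessCtx>::New();
    auto* ctx = Global<ProcessCtx>::Get();
    for (int i = 0; i < world_size; ++i) { ctx->mutable_ctrl_addr()->Add(); }
    ctx->set_rank(0);
    ctx->set_node_size(node_size);
  }
  ~ProcessCtxScope() { Global<ProcessCtx>::Delete(); }
};

}  // namespace test
}  // namespace oneflow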
diff --git a/python/oneflow/__init__.py b/python/oneflow/__init__.py
index 7131a9a8482..19f5f388325 100644
--- a/python/oneflow/__init__.py
+++ b/python/oneflow/__init__.py
@@ -79,6 +79,7 @@ def is_deprecated(func_or_class):
 env_util.SetDefaultMultiClientEnvVars()
 oneflow._oneflow_internal.SetIsMultiClient(True)
 env_util.api_env_init()
+oneflow._oneflow_internal.InitDefaultConsistentRpcTokenScope()
 session_ctx.OpenDefaultSession(
     MultiClientSession(oneflow._oneflow_internal.NewSessionId())
 )
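For context, oneflow._oneflow_internal.InitDefaultConsistentRpcTokenScope() is invoked after env initialization and before the default session is opened, which suggests it installs the process-wide initial RankGroupScope introduced by this patch. The snippet below is an assumed sketch of that wiring; the actual implementation lives in oneflow/api/python/rpc/consistent_rpc_token_scope.cpp and may differ, and the helper name and world_size parameter are illustrative.

#include <memory>
#include <set>
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/job/rank_group.h"
#include "oneflow/core/job/rank_group_scope.h"

namespace oneflow {

// Assumed sketch only: build the rank group {0, ..., world_size - 1} and keep
// the initial scope alive for the remainder of the process.
Maybe<void> InitDefaultRankGroupScopeSketch(int64_t world_size) {
  std::set<int64_t> all_ranks;
  for (int64_t i = 0; i < world_size; ++i) { all_ranks.insert(i); }
  const auto& rank_group = JUST(RankGroup::New(all_ranks));
  static std::shared_ptr<RankGroupScope> initial_scope =
      JUST(RankGroupScope::MakeInitialRankGroupScope(rank_group));
  (void)initial_scope;
  return Maybe<void>::Ok();
}

}  // namespace oneflow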