New sync consistent meta info (#5634)
* rebase

* check in gen py

* merge master and fix bugs

* address pr comments

* address pr comments

* auto format by CI

* rebase

* address pr comments

* auto format by CI

* functional python_arg

* reuse ctrl rpc token to avoid long timeout waits

* fix compiler complaints

* auto format by CI

* auto format by CI

* remove unused files

* fix return type error on gcc 4.8.5

Signed-off-by: daquexian <daquexian566@gmail.com>

* auto format by CI

* fix return type error in xrt

Signed-off-by: daquexian <daquexian566@gmail.com>

* fix tick ibn sbp signature

* auto format by CI

Co-authored-by: tsai <jackalcooper@gmail.com>
Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Co-authored-by: daquexian <daquexian566@gmail.com>
5 people committed Jul 31, 2021
1 parent e81dafc commit 0a44b54
Showing 58 changed files with 2,345 additions and 157 deletions.
20 changes: 20 additions & 0 deletions oneflow/api/python/framework/tensor.cpp
@@ -26,6 +26,7 @@ limitations under the License.
#include "oneflow/core/common/tensor_buffer.h"
#include "oneflow/core/framework/instructions_builder.h"
#include "oneflow/core/framework/tensor.h"
#include "oneflow/core/framework/tensor_rpc_util.h"
#include "oneflow/core/framework/tensor_method.h"
#include "oneflow/core/framework/device.h"
#include "oneflow/core/framework/stride.h"
@@ -247,6 +248,13 @@ void ApiRegisterTensorHook(const std::shared_ptr<Tensor>& self, const AutogradMe
return RegisterTensorHook(self, hook).GetOrThrow();
}

Maybe<void> CheckConsistentTensorMeta(const one::Tensor& tensor, int64_t seconds) {
const auto& ctx = JUST(LaunchTensorMetaConsistencyCheck(tensor));
JUST(RpcUtil::WaitUntilDoneOrTimeout(*ctx, seconds));
JUST(ctx->Check());
return Maybe<void>::Ok();
}

bool ApiIsContiguous(const std::shared_ptr<Tensor>& tensor) {
return IsContiguous(tensor).GetOrThrow();
}
@@ -403,6 +411,18 @@ ONEFLOW_API_PYBIND11_MODULE("", m) {
.def_property_readonly("_tensor_buffer_shapes_and_dtypes", &GetTensorBufferShapesAndDTypes)
.def_property_readonly("device", &TensorGetDevice)
.def_property_readonly("data", &Tensor::data)
.def("rpc_token",
[](const one::Tensor& tensor) -> int64_t {
return static_cast<uint64_t>(tensor.rpc_token().GetOrThrow());
})
.def("check_meta_consistency",
[](const one::Tensor& tensor) {
return CheckConsistentTensorMeta(tensor, 60 * 5).GetOrThrow();
})
.def("check_meta_consistency",
[](const one::Tensor& tensor, int64_t seconds) {
return CheckConsistentTensorMeta(tensor, seconds).GetOrThrow();
})
#define DEFINE_TENSOR_METHOD(T, type_proto) \
.def("_copy_to_numpy_" #T, &ApiCopyMirroredTensorToNumpy<T>) \
.def("_copy_from_numpy_" #T, &ApiCopyMirroredTensorFromNumpy<T>)
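
The check_meta_consistency binding added above wraps a three-step protocol: asynchronously publish this rank's view of the tensor meta (LaunchTensorMetaConsistencyCheck), wait for the peer ranks with a bounded timeout (RpcUtil::WaitUntilDoneOrTimeout), then compare the collected metas (ctx->Check()). Below is a minimal sketch of the same sequence with a caller-chosen timeout, reusing only calls visible in this diff; the function name ValidateTensorMeta and the 30-second budget are illustrative, not part of the commit.

#include "oneflow/core/common/maybe.h"
#include "oneflow/core/framework/tensor.h"
#include "oneflow/core/framework/tensor_rpc_util.h"

namespace oneflow {

// Illustrative sketch: same steps as CheckConsistentTensorMeta above, but with
// a caller-chosen timeout instead of the 5-minute default used by the binding.
Maybe<void> ValidateTensorMeta(const one::Tensor& tensor) {
  const auto& ctx = JUST(LaunchTensorMetaConsistencyCheck(tensor));  // async publish of local meta
  JUST(RpcUtil::WaitUntilDoneOrTimeout(*ctx, /*seconds=*/30));       // bounded wait for peer ranks
  JUST(ctx->Check());                                                // compare local vs. received meta
  return Maybe<void>::Ok();
}

}  // namespace oneflow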
60 changes: 60 additions & 0 deletions oneflow/api/python/rpc/consistent_rpc_token_scope.cpp
@@ -0,0 +1,60 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/functional.h>
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/thread/consistent_unique_id.h"
#include "oneflow/core/framework/rank_group_rpc_util.h"
#include "oneflow/core/job/rank_group.h"
#include "oneflow/core/job/rank_group_scope.h"
#include "oneflow/core/common/symbol.h"

namespace py = pybind11;

namespace oneflow {

namespace {

Maybe<void> InitConsistentRpcTokenScope(const std::string& thread_tag,
int64_t thread_consistent_uid,
Symbol<RankGroup> rank_group) {
JUST(SetThisThreadConsistentUniqueId(thread_consistent_uid, thread_tag));
static thread_local const auto& init_rank_group_scope =
JUST(RankGroupScope::MakeInitialRankGroupScope(rank_group));
// no unused warning for `init_rank_group_scope`.
(void)(init_rank_group_scope);
return Maybe<void>::Ok();
}

Maybe<void> InitConsistentRpcTokenScope(const std::string& thread_tag,
int64_t thread_consistent_uid) {
const auto& rank_group = JUST(RankGroup::DefaultRankGroup());
JUST(InitConsistentRpcTokenScope(thread_tag, thread_consistent_uid, rank_group));
return Maybe<void>::Ok();
}

void ApiInitDefaultConsistentRpcTokenScope() {
return InitConsistentRpcTokenScope("main", 0).GetOrThrow();
}

} // namespace

ONEFLOW_API_PYBIND11_MODULE("", m) {
m.def("InitDefaultConsistentRpcTokenScope", &ApiInitDefaultConsistentRpcTokenScope);
}

} // namespace oneflow
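
InitDefaultConsistentRpcTokenScope registers a consistent unique id for the calling thread (tagged "main", uid 0) and establishes an initial RankGroupScope on the default rank group. Below is a minimal sketch of the same two steps for a hypothetical secondary thread; the uid 1, the "worker" tag, and the function name are illustrative, not part of the commit.

#include "oneflow/core/common/maybe.h"
#include "oneflow/core/job/rank_group.h"
#include "oneflow/core/job/rank_group_scope.h"
#include "oneflow/core/thread/consistent_unique_id.h"

namespace oneflow {

// Illustrative sketch: mirror InitConsistentRpcTokenScope for a second thread.
Maybe<void> InitWorkerThreadRpcTokenScope() {
  JUST(SetThisThreadConsistentUniqueId(/*thread_consistent_uid=*/1, "worker"));
  const auto& rank_group = JUST(RankGroup::DefaultRankGroup());
  // Keep the initial rank-group scope alive for the lifetime of this thread.
  static thread_local const auto& scope =
      JUST(RankGroupScope::MakeInitialRankGroupScope(rank_group));
  (void)scope;
  return Maybe<void>::Ok();
}

}  // namespace oneflow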
47 changes: 47 additions & 0 deletions oneflow/api/python/rpc/rank_group.cpp
@@ -0,0 +1,47 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/functional.h>
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/framework/rank_group_rpc_util.h"
#include "oneflow/core/job/rank_group.h"
#include "oneflow/core/job/rank_group_scope.h"
#include "oneflow/core/common/symbol.h"

namespace py = pybind11;

namespace oneflow {

namespace {

Maybe<void> CheckCurrentRankGroupConsistency(int64_t seconds) {
const auto& rank_group = JUST(RankGroupScope::CurrentRankGroup());
const auto& ctx = JUST(CheckRpcToken(rank_group));
JUST(RpcUtil::WaitUntilDoneOrTimeout(*ctx, seconds));
return Maybe<void>::Ok();
}

} // namespace

ONEFLOW_API_PYBIND11_MODULE("", m) {
m.def("check_current_rank_group_consistency",
[](int64_t seconds) { return CheckCurrentRankGroupConsistency(seconds).GetOrThrow(); });
m.def("check_current_rank_group_consistency",
[]() { return CheckCurrentRankGroupConsistency(60 * 5).GetOrThrow(); });
}

} // namespace oneflow
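
check_current_rank_group_consistency is a bounded cross-rank handshake: each rank issues a token check against the current rank group and waits up to the given number of seconds (five minutes by default) for its peers. A sketch of the same pattern with a shorter budget follows; the helper name and the 60-second value are illustrative.

#include "oneflow/core/common/maybe.h"
#include "oneflow/core/framework/rank_group_rpc_util.h"
#include "oneflow/core/job/rank_group_scope.h"

namespace oneflow {

// Illustrative sketch: same calls as CheckCurrentRankGroupConsistency above,
// with a one-minute budget instead of the five-minute default.
Maybe<void> ProbeRankGroupOnce() {
  const auto& rank_group = JUST(RankGroupScope::CurrentRankGroup());
  const auto& ctx = JUST(CheckRpcToken(rank_group));            // exchange tokens with peer ranks
  JUST(RpcUtil::WaitUntilDoneOrTimeout(*ctx, /*seconds=*/60));  // bounded wait
  return Maybe<void>::Ok();
}

}  // namespace oneflow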
1 change: 1 addition & 0 deletions oneflow/api/python/symbol/placement_symbol.cpp
@@ -193,6 +193,7 @@ struct PlacementSymbolExportUtil {
return placement_str;
}
};

} // namespace

ONEFLOW_API_PYBIND11_MODULE("", m) {
6 changes: 6 additions & 0 deletions oneflow/core/common/error.cpp
@@ -76,6 +76,12 @@ Error Error::IndexError() {
return error;
}

Error Error::TimeoutError() {
auto error = std::make_shared<cfg::ErrorProto>();
error->mutable_timeout_error();
return error;
}

Error Error::JobNameExistError() {
auto error = std::make_shared<cfg::ErrorProto>();
error->mutable_job_name_exist_error();
1 change: 1 addition & 0 deletions oneflow/core/common/error.h
@@ -44,6 +44,7 @@ class Error final {
static Error DeviceTagNotFoundError();
static Error ValueError(const std::string& error_summary);
static Error IndexError();
static Error TimeoutError();
static Error JobNameExistError();
static Error JobNameEmptyError();
static Error JobNameNotEqualError();
3 changes: 3 additions & 0 deletions oneflow/core/common/error.proto
@@ -127,6 +127,8 @@ message ValueError {}

message IndexError {}

message TimeoutError {}

message ErrorProto {
optional string error_summary = 1 [default = ""];
optional string msg = 2 [default = ""];
@@ -148,6 +150,7 @@ message ErrorProto {
DeviceTagNotFoundError device_tag_not_found_error = 26;
ValueError value_error = 27;
IndexError index_error = 28;
TimeoutError timeout_error = 29;
JobNameExistError job_name_exist_error = 100;
JobNameEmptyError job_name_empty_error = 101;
JobNameNotEqualError job_name_not_equal_error = 102;
1 change: 1 addition & 0 deletions oneflow/core/common/exception.h
@@ -70,6 +70,7 @@ class Exception : public std::exception {
OF_PP_MAKE_TUPLE_SEQ(CompileOptionWrong) \
OF_PP_MAKE_TUPLE_SEQ(Value) \
OF_PP_MAKE_TUPLE_SEQ(Index) \
OF_PP_MAKE_TUPLE_SEQ(Timeout) \
OF_PP_MAKE_TUPLE_SEQ(InputDeviceNotMatch)

#define DEFINE_EXCEPTION_CLASS(cls) \
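
The changes to error.cpp, error.h, error.proto, and exception.h add a timeout category end to end: an Error::TimeoutError() factory, a TimeoutError proto message, and a Timeout entry in the exception class list, presumably so a blocked consistency check can fail with a dedicated error instead of hanging. Below is a minimal sketch of reporting it from a deadline check; the function name and the deadline logic are illustrative, and it assumes the usual construction of Maybe<void> from an Error in this codebase.

#include <chrono>
#include "oneflow/core/common/error.h"
#include "oneflow/core/common/maybe.h"

namespace oneflow {

// Illustrative sketch: surface the new error type once a deadline has passed.
Maybe<void> CheckDeadline(std::chrono::steady_clock::time_point deadline) {
  if (std::chrono::steady_clock::now() > deadline) { return Error::TimeoutError(); }
  return Maybe<void>::Ok();
}

}  // namespace oneflow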
36 changes: 36 additions & 0 deletions oneflow/core/common/flat_shape.cpp
@@ -0,0 +1,36 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/flat_shape.h"
#include "oneflow/core/common/shape.h"

namespace oneflow {

Maybe<void> FlatShape::Init(const std::shared_ptr<const Shape>& shape) {
CHECK_LE_OR_RETURN(shape->NumAxes(), SHAPE_MAX_AXIS_SIZE);
this->set_num_axes(shape->NumAxes());
for (int i = 0; i < this->num_axes(); ++i) { *this->mutable_dim()->Mutable(i) = shape->At(i); }
return Maybe<void>::Ok();
}

Maybe<void> FlatShape::Check(const std::shared_ptr<const Shape>& shape) const {
CHECK_EQ_OR_RETURN(this->num_axes(), shape->NumAxes());
for (int i = 0; i < this->num_axes(); ++i) {
CHECK_EQ_OR_RETURN(this->dim().Get(i), shape->At(i));
}
return Maybe<void>::Ok();
}

} // namespace oneflow
44 changes: 44 additions & 0 deletions oneflow/core/common/flat_shape.h
@@ -0,0 +1,44 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_FLAT_SHAPE_H_
#define ONEFLOW_CORE_COMMON_FLAT_SHAPE_H_

#include <memory>
#include "oneflow/core/object_msg/flat_msg.h"
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/common/shape_vec.h"

namespace oneflow {

class Shape;

// clang-format off

FLAT_MSG_BEGIN(FlatShape);
// Methods
OF_PUBLIC Maybe<void> Init(const std::shared_ptr<const Shape>& shape);
OF_PUBLIC Maybe<void> Check(const std::shared_ptr<const Shape>& shape) const;

// Fields
FLAT_MSG_DEFINE_OPTIONAL(int64_t, num_axes);
FLAT_MSG_DEFINE_REPEATED(int64_t, dim, SHAPE_MAX_AXIS_SIZE);
FLAT_MSG_END(FlatShape);

// clang-format on

} // namespace oneflow

#endif // ONEFLOW_CORE_COMMON_FLAT_SHAPE_H_
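
FlatShape stores a shape's rank and per-axis sizes in a fixed-capacity flat message (at most SHAPE_MAX_AXIS_SIZE axes), which makes it suitable for the flat RPC payloads used by the meta-consistency checks: Init() copies a Shape in, Check() verifies another Shape against the stored dims. A small usage sketch follows; the FlatMsg<> wrapper and the DimVector-based Shape constructor are assumptions drawn from how flat messages and shapes are used elsewhere in the codebase, not from this diff.

#include <memory>
#include "oneflow/core/common/flat_shape.h"
#include "oneflow/core/common/shape.h"
#include "oneflow/core/object_msg/flat_msg.h"

namespace oneflow {

// Illustrative sketch: pack a Shape into a FlatShape, then verify it.
Maybe<void> RoundTripShape() {
  const auto shape = std::make_shared<const Shape>(DimVector{2, 3, 4});  // assumed Shape ctor
  FlatMsg<FlatShape> flat;                                               // assumed flat-msg wrapper
  JUST(flat.Mutable()->Init(shape));   // copies num_axes and each dim
  JUST(flat.Get().Check(shape));       // passes: same rank and per-axis sizes
  return Maybe<void>::Ok();
}

}  // namespace oneflow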
6 changes: 6 additions & 0 deletions oneflow/core/common/maybe.h
@@ -149,12 +149,18 @@ class Maybe<T, typename std::enable_if<std::is_same<T, void>::value>::type> fina
SharedOrScalar<cfg::ErrorProto, void*> error_or_scalar_;
};

inline const std::shared_ptr<cfg::ErrorProto>& UninitializedValueError() {
static thread_local const auto& error = Error::ValueError("uninitialized value").error_proto();
return error;
}

template<typename T>
class Maybe<T, typename std::enable_if<IsScalarType<T>::value>::type> final {
public:
Maybe(T data) : error_or_scalar_(data) {}
Maybe(const Error& error) : error_or_scalar_(error.error_proto()) { CheckError(); }
Maybe(const std::shared_ptr<cfg::ErrorProto>& error) : error_or_scalar_(error) { CheckError(); }
Maybe() : error_or_scalar_(UninitializedValueError()) {}
Maybe(const Maybe&) = default;
Maybe(Maybe&&) = default;
~Maybe() = default;
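
The new default constructor makes a default-constructed scalar Maybe<T> well-formed: it carries the thread-local "uninitialized value" error until a real value (or a real error) replaces it, which lets such values be declared before they are produced. A minimal sketch; the function and its argument are illustrative.

#include <cstdint>
#include "oneflow/core/common/maybe.h"

namespace oneflow {

// Illustrative sketch: a default-constructed scalar Maybe now compiles and
// holds the "uninitialized value" error rather than an indeterminate state.
Maybe<int64_t> LookupOrUninitialized(bool found) {
  if (!found) { return Maybe<int64_t>(); }  // previously this required an explicit Error
  return 42;
}

}  // namespace oneflow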
4 changes: 0 additions & 4 deletions oneflow/core/control/ctrl_bootstrap.proto
@@ -20,7 +20,3 @@ message BootstrapConf {
optional int32 ctrl_port = 5 [default = -1];
optional int64 node_size = 6 [default = -1];
}

message NumProcessPerNode {
required int64 value = 1;
}
12 changes: 0 additions & 12 deletions oneflow/core/eager/init_symbol_instruction_type_test.cpp
@@ -39,14 +39,6 @@ namespace oneflow {
namespace vm {
namespace test {

namespace {

void InitNumProcessPerNode() { Global<NumProcessPerNode>::New()->set_value(1); }

void DestroyNumProcessPerNode() { Global<NumProcessPerNode>::Delete(); }

} // namespace

using InstructionMsgList = OBJECT_MSG_LIST(vm::InstructionMsg, instr_msg_link);

template<typename T, typename SerializedT>
@@ -72,25 +64,21 @@ void TestInitSymbolInstructionType(const std::string& instr_type_name) {
}

TEST(InitSymbolInstructionType, job_desc) {
InitNumProcessPerNode();
#ifdef WITH_CUDA
vm::TestResourceDescScope resource_scope(1, 1);
#else
vm::TestResourceDescScope resource_scope(0, 1);
#endif
TestInitSymbolInstructionType<JobDesc, JobConfigProto>("InitJobDescSymbol");
DestroyNumProcessPerNode();
}

TEST(InitSymbolInstructionType, operator_conf) {
InitNumProcessPerNode();
#ifdef WITH_CUDA
vm::TestResourceDescScope resource_scope(1, 1);
#else
vm::TestResourceDescScope resource_scope(0, 1);
#endif
TestInitSymbolInstructionType<OperatorConfSymbol, OperatorConf>("InitOperatorConfSymbol");
DestroyNumProcessPerNode();
}

} // namespace test
15 changes: 0 additions & 15 deletions oneflow/core/eager/lazy_job_instruction_type_test.cpp
@@ -40,17 +40,6 @@ namespace oneflow {
namespace vm {
namespace test {

namespace {

void InitNumProcessPerNode() {
Global<NumProcessPerNode>::New();
Global<NumProcessPerNode>::Get()->set_value(1);
}

void DestroyNumProcessPerNode() { Global<NumProcessPerNode>::Delete(); }

} // namespace

using InstructionMsgList = OBJECT_MSG_LIST(vm::InstructionMsg, instr_msg_link);

class NoArgNoRetMockNNGraph : public NNGraphIf {
@@ -73,7 +62,6 @@ class NoArgNoRetMockNNGraph : public NNGraphIf {
};

TEST(RunLazyJobInstructionType, simple) {
InitNumProcessPerNode();
vm::TestResourceDescScope resource_scope(0, 1);
auto vm_desc = ObjectMsgPtr<vm::VmDesc>::New(vm::TestUtil::NewVmResourceDesc().Get());
vm::TestUtil::AddStreamDescByInstrNames(vm_desc.Mutable(), {"RunLazyJob"});
@@ -122,11 +110,9 @@ TEST(RunLazyJobInstructionType, simple) {
leave_thread.join();
enter_thread.join();
Global<BufferMgr<std::shared_ptr<JobInstance>>>::Delete();
DestroyNumProcessPerNode();
}

TEST(RunLazyJobInstructionType, wait_for_another_job_finished) {
InitNumProcessPerNode();
vm::TestResourceDescScope resource_scope(0, 1);
auto vm_desc = ObjectMsgPtr<vm::VmDesc>::New(vm::TestUtil::NewVmResourceDesc().Get());
vm::TestUtil::AddStreamDescByInstrNames(vm_desc.Mutable(), {"RunLazyJob"});
@@ -247,7 +233,6 @@ TEST(RunLazyJobInstructionType, wait_for_another_job_finished) {
enter_thread0.join();
enter_thread1.join();
Global<BufferMgr<std::shared_ptr<JobInstance>>>::Delete();
DestroyNumProcessPerNode();
}

} // namespace test