New sync consistent meta info #5634

Merged Jul 31, 2021 (76 commits)

Commits
d248b3d
rebase
jackalcooper Jul 26, 2021
68bb49f
Merge branch 'master' into deduce_consistent_op_interpreter
lixinqi Jul 26, 2021
6e8a083
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
jackalcooper Jul 26, 2021
f7aa67f
check in gen py
jackalcooper Jul 26, 2021
9a2f9d5
Merge branch 'deduce_consistent_op_interpreter' of github.com:Oneflow…
lixinqi Jul 26, 2021
122acb7
merge master and fix bugs
lixinqi Jul 26, 2021
43a2ac5
Merge branch 'master' into deduce_consistent_op_interpreter
lixinqi Jul 26, 2021
0101184
address pr comments
lixinqi Jul 26, 2021
f2b0c79
address pr comments
lixinqi Jul 26, 2021
e670f76
merge master
lixinqi Jul 26, 2021
5be290a
auto format by CI
oneflow-ci-bot Jul 27, 2021
e3db142
Merge branch 'master' into deduce_consistent_op_interpreter
lixinqi Jul 27, 2021
9f8bd32
rebase
lixinqi Jul 27, 2021
cf5255e
Merge branch 'master' into deduce_consistent_op_interpreter
lixinqi Jul 27, 2021
522ee5a
address pr comments
lixinqi Jul 27, 2021
45cf7b9
auto format by CI
oneflow-ci-bot Jul 27, 2021
8bdb533
Merge branch 'master' into deduce_consistent_op_interpreter
lixinqi Jul 27, 2021
11fd3ee
functional python_arg
lixinqi Jul 27, 2021
d0ccb71
Merge branch 'deduce_consistent_op_interpreter' into new_sync_consist…
lixinqi Jul 27, 2021
fd5cd19
Merge branch 'master' into deduce_consistent_op_interpreter
lixinqi Jul 27, 2021
3d7e511
Merge branch 'deduce_consistent_op_interpreter' into new_sync_consist…
lixinqi Jul 27, 2021
0257dc0
reuse ctrl rpc token for avoiding long time timeout waiting.
lixinqi Jul 27, 2021
d31afb9
Merge branch 'new_sync_consistent_meta_info' of github.com:Oneflow-In…
lixinqi Jul 27, 2021
84a91b5
merge master
lixinqi Jul 28, 2021
cb4695e
Merge branch 'master' into deduce_consistent_op_interpreter
lixinqi Jul 28, 2021
49987b9
Merge branch 'deduce_consistent_op_interpreter' into new_sync_consist…
lixinqi Jul 28, 2021
4c8fb6b
fix compiler complaints
lixinqi Jul 28, 2021
ff09a67
Merge branch 'master' into new_sync_consistent_meta_info
oneflow-ci-bot Jul 28, 2021
5748059
auto format by CI
oneflow-ci-bot Jul 28, 2021
5ae441e
Merge branch 'master' into new_sync_consistent_meta_info
oneflow-ci-bot Jul 28, 2021
e8a5da8
Merge branch 'master' into new_sync_consistent_meta_info
oneflow-ci-bot Jul 28, 2021
8250d4e
Merge branch 'master' into deduce_consistent_op_interpreter
lixinqi Jul 28, 2021
b060d98
Merge branch 'master' into new_sync_consistent_meta_info
oneflow-ci-bot Jul 28, 2021
d7313d7
Merge branch 'master' into deduce_consistent_op_interpreter
oneflow-ci-bot Jul 28, 2021
196ce90
auto format by CI
oneflow-ci-bot Jul 28, 2021
4f3dc4b
Merge branch 'master' into deduce_consistent_op_interpreter
oneflow-ci-bot Jul 28, 2021
a1c9789
Merge branch 'master' into new_sync_consistent_meta_info
oneflow-ci-bot Jul 28, 2021
47add8a
remove unused files
lixinqi Jul 29, 2021
88bec9e
Merge branch 'master' into deduce_consistent_op_interpreter
oneflow-ci-bot Jul 29, 2021
268e388
Merge branch 'master' into new_sync_consistent_meta_info
oneflow-ci-bot Jul 29, 2021
a559ad8
Merge branch 'master' into deduce_consistent_op_interpreter
oneflow-ci-bot Jul 29, 2021
1a9013f
Merge branch 'master' into new_sync_consistent_meta_info
oneflow-ci-bot Jul 29, 2021
bd1a6ea
Merge branch 'master' into deduce_consistent_op_interpreter
lixinqi Jul 29, 2021
994ad1c
Merge branch 'master' into deduce_consistent_op_interpreter
oneflow-ci-bot Jul 29, 2021
c1de499
Merge branch 'master' into new_sync_consistent_meta_info
oneflow-ci-bot Jul 29, 2021
acefe4a
fix return type error on gcc 4.8.5
daquexian Jul 29, 2021
61fa201
Merge branch 'master' into fix_no_return
oneflow-ci-bot Jul 29, 2021
b027fe8
auto format by CI
oneflow-ci-bot Jul 29, 2021
91672cb
Merge branch 'master' into deduce_consistent_op_interpreter
lixinqi Jul 29, 2021
539006c
Merge branch 'master' into new_sync_consistent_meta_info
oneflow-ci-bot Jul 29, 2021
d66388a
Merge branch 'master' into deduce_consistent_op_interpreter
oneflow-ci-bot Jul 29, 2021
aa1313a
Merge branch 'master' into new_sync_consistent_meta_info
oneflow-ci-bot Jul 29, 2021
16efab2
Merge branch 'master' into fix_no_return
oneflow-ci-bot Jul 29, 2021
430c777
Merge branch 'master' into deduce_consistent_op_interpreter
oneflow-ci-bot Jul 29, 2021
9699ed5
Merge branch 'master' into new_sync_consistent_meta_info
oneflow-ci-bot Jul 29, 2021
2621510
Merge branch 'master' into fix_no_return
oneflow-ci-bot Jul 29, 2021
32d59ba
Merge branch 'master' into fix_no_return
oneflow-ci-bot Jul 29, 2021
659af30
Merge branch 'master' into fix_no_return
oneflow-ci-bot Jul 29, 2021
2f986f9
Merge branch 'master' into fix_no_return
oneflow-ci-bot Jul 30, 2021
946d431
fix return type error in xrt
daquexian Jul 30, 2021
8761227
Merge branch 'fix_no_return' of github.com:Oneflow-Inc/oneflow into f…
daquexian Jul 30, 2021
64b5c68
merge master
lixinqi Jul 30, 2021
3b2347d
Merge branch 'fix_no_return' into deduce_consistent_op_interpreter
lixinqi Jul 30, 2021
af2e270
fix tick ibn sbp signature
lixinqi Jul 30, 2021
7ab98b2
Merge branch 'master' into deduce_consistent_op_interpreter
lixinqi Jul 30, 2021
d6b0561
Merge branch 'deduce_consistent_op_interpreter' of github.com:Oneflow…
lixinqi Jul 30, 2021
48fe2b4
merge deduce_op_interpreter
lixinqi Jul 30, 2021
f86248c
Merge branch 'new_sync_consistent_meta_info' of github.com:Oneflow-In…
lixinqi Jul 30, 2021
093f7c8
Merge branch 'master' into new_sync_consistent_meta_info
oneflow-ci-bot Jul 30, 2021
a499124
auto format by CI
oneflow-ci-bot Jul 30, 2021
1d76e61
merge master
lixinqi Jul 30, 2021
b92c7e8
Merge branch 'master' into new_sync_consistent_meta_info
oneflow-ci-bot Jul 30, 2021
78ffa63
Merge branch 'master' into new_sync_consistent_meta_info
oneflow-ci-bot Jul 30, 2021
c80d736
Merge branch 'master' into new_sync_consistent_meta_info
oneflow-ci-bot Jul 30, 2021
ece303a
Merge branch 'master' into new_sync_consistent_meta_info
oneflow-ci-bot Jul 30, 2021
8eccef7
Merge branch 'master' into new_sync_consistent_meta_info
oneflow-ci-bot Jul 30, 2021
3 changes: 1 addition & 2 deletions oneflow/api/python/autograd/autograd.cpp
@@ -54,9 +54,8 @@ Maybe<one::TensorTuple> CheckAndInitOutGrads(const one::TensorTuple& outputs,
CHECK_OR_RETURN(IsScalarTensor(*outputs.at(i)))
<< "Grad can be implicitly created only for scalar outputs";
const auto& ones_like = JUST(op_expr_helper::OnesLikeOp());
const auto& interpreter = JUST(one::OpInterpUtil::GetInterpreter());
one::TensorTuple grad_output(1);
JUST(interpreter->Apply(*ones_like, one::TensorTuple{outputs.at(i)}, &grad_output));
JUST(one::OpInterpUtil::Dispatch(*ones_like, one::TensorTuple{outputs.at(i)}, &grad_output));
gradients->at(i) = grad_output.at(0);
} else {
CHECK_OR_RETURN(*(outputs.at(i)->shape()) == *(out_grads.at(i)->shape()))
6 changes: 4 additions & 2 deletions oneflow/api/python/framework/device.cpp
@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>
#include <pybind11/operators.h>
#include "oneflow/api/python/common.h"
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/framework/device.h"
@@ -54,9 +55,10 @@ ONEFLOW_API_PYBIND11_MODULE("", m) {
}))
.def_property_readonly("type", [](const Symbol<Device>& d) { return d->type(); })
.def_property_readonly("index", [](const Symbol<Device>& d) { return d->device_id(); })
.def("__eq__", [](const Symbol<Device>& d1, const Symbol<Device>& d2) { return *d1 == *d2; })
.def("__str__", [](const Symbol<Device>& d) { return d->ToString(); })
.def("__repr__", [](const Symbol<Device>& d) { return d->ToRepr(); });
.def("__repr__", [](const Symbol<Device>& d) { return d->ToRepr(); })
.def(py::self == py::self)
.def(py::hash(py::self));
}

} // namespace oneflow
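The change above swaps the hand-written `__eq__` lambda for pybind11's `py::self == py::self` and adds `py::hash(py::self)`; both dispatch to a C++-side `operator==` and `std::hash` for the bound type. A minimal standalone sketch of that contract, using a hypothetical `DeviceLike` stand-in for `Symbol<Device>` (not oneflow's actual type):

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>
#include <unordered_set>

// Hypothetical stand-in for a device symbol: equality and hashing are
// exactly what py::self == py::self and py::hash(py::self) dispatch to.
struct DeviceLike {
  std::string type;
  int index;
  bool operator==(const DeviceLike& rhs) const {
    return type == rhs.type && index == rhs.index;
  }
};

// A std::hash specialization is what makes the bound type usable as a
// Python dict key once __hash__ is exposed.
namespace std {
template<>
struct hash<DeviceLike> {
  size_t operator()(const DeviceLike& d) const {
    return hash<string>()(d.type) ^ (hash<int>()(d.index) << 1);
  }
};
}  // namespace std

bool same_device(const DeviceLike& a, const DeviceLike& b) { return a == b; }
```

With both pieces in place, equal devices compare equal and collapse to one entry in a hash container, mirroring the Python-side behavior the binding enables.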
3 changes: 1 addition & 2 deletions oneflow/api/python/framework/op_expr.cpp
@@ -37,8 +37,7 @@ Maybe<one::TensorTuple> Interpret(const one::OpExpr& op, const one::TensorTuple&
<< "The operation requires " << op.input_size() << " inputs, but " << inputs.size()
<< " is given.";
auto outputs = std::make_shared<one::TensorTuple>(op.output_size());
auto interperter = JUST(one::OpInterpUtil::GetInterpreter());
JUST(interperter->Apply(op, inputs, outputs.get(), attrs));
JUST(one::OpInterpUtil::Dispatch(op, inputs, outputs.get(), attrs));
return outputs;
}

36 changes: 35 additions & 1 deletion oneflow/api/python/framework/tensor.cpp
@@ -22,6 +22,7 @@ limitations under the License.
#include "oneflow/core/common/tensor_buffer.h"
#include "oneflow/core/framework/instructions_builder.h"
#include "oneflow/core/framework/tensor.h"
#include "oneflow/core/framework/tensor_rpc_util.h"
#include "oneflow/core/framework/tensor_method.h"
#include "oneflow/core/framework/device.h"
#include "oneflow/core/framework/stride.h"
@@ -203,10 +204,30 @@ void ApiRegisterTensorHook(const std::shared_ptr<Tensor>& self, const AutogradMe
return RegisterTensorHook(self, hook).GetOrThrow();
}

Maybe<void> CheckConsistentTensorMeta(const one::Tensor& tensor, int64_t seconds) {
const auto& ctx = JUST(LaunchTensorMetaConsistencyCheck(tensor));
JUST(RpcUtil::WaitUntilDoneOrTimeout(*ctx, seconds));
JUST(ctx->Check());
return Maybe<void>::Ok();
}

bool ApiIsContiguous(const std::shared_ptr<Tensor>& tensor) {
return IsContiguous(tensor).GetOrThrow();
}

Maybe<py::tuple> TensorGetPyTupleOfSbp(const Tensor& tensor) {
const auto& nd_sbp = JUST(tensor.parallel_distribution());
const auto& tuple = std::make_shared<py::tuple>(nd_sbp->sbp_parallel_size());
for (int i = 0; i < nd_sbp->sbp_parallel_size(); ++i) {
(*tuple)[i] = SymbolOf(nd_sbp->sbp_parallel(i));
}
return tuple;
}

py::tuple ApiTensorGetPyTupleOfSbp(const Tensor& tensor) {
return *TensorGetPyTupleOfSbp(tensor).GetPtrOrThrow();
}

} // namespace

ONEFLOW_API_PYBIND11_MODULE("", m) {
@@ -270,6 +291,18 @@ ONEFLOW_API_PYBIND11_MODULE("", m) {
.def_property_readonly("_tensor_buffer_shapes_and_dtypes", &GetTensorBufferShapesAndDTypes)
.def_property_readonly("device", &TensorGetDevice)
.def_property_readonly("data", &Tensor::data)
.def("rpc_token",
[](const one::Tensor& tensor) -> int64_t {
return static_cast<uint64_t>(tensor.rpc_token().GetOrThrow());
})
.def("check_meta_consistency",
[](const one::Tensor& tensor) {
return CheckConsistentTensorMeta(tensor, 60 * 5).GetOrThrow();
})
.def("check_meta_consistency",
[](const one::Tensor& tensor, int64_t seconds) {
return CheckConsistentTensorMeta(tensor, seconds).GetOrThrow();
})
#define DEFINE_TENSOR_METHOD(T, type_proto) \
.def("_copy_to_numpy_" #T, &ApiCopyMirroredTensorToNumpy<T>) \
.def("_copy_from_numpy_" #T, &ApiCopyMirroredTensorFromNumpy<T>)
@@ -279,7 +312,8 @@ ONEFLOW_API_PYBIND11_MODULE("", m) {
.def("_get_copy_mirrored_tensor_from_numpy_func_name",
&ApiGetCopyMirroredTensorFromNumpyFuncName)
// consistent tensor only
.def_property_readonly("placement", &TensorGetParallelDesc);
.def_property_readonly("placement", &TensorGetParallelDesc)
.def_property_readonly("sbp", &ApiTensorGetPyTupleOfSbp);
}

} // namespace one
6 changes: 6 additions & 0 deletions oneflow/api/python/functional/python_arg.cpp
@@ -21,6 +21,7 @@ limitations under the License.
#include "oneflow/core/common/data_type.cfg.h"
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/dtype.h"
#include "oneflow/core/framework/device.h"
#include "oneflow/core/framework/tensor.h"
#include "oneflow/core/framework/tensor_tuple.h"
#include "oneflow/core/framework/user_op_attr.cfg.h"
@@ -168,6 +169,11 @@ Maybe<one::Generator> PythonArg::ObjectAs<one::Generator>() const {
return *JUST(detail::cast<std::shared_ptr<one::Generator>>(Borrow()));
}

template<>
Maybe<Symbol<Device>> PythonArg::ObjectAs<Symbol<Device>>() const {
return **JUST(detail::cast<std::shared_ptr<Symbol<Device>>>(Borrow()));
}

template<>
Maybe<Symbol<ParallelDesc>> PythonArg::ObjectAs<Symbol<ParallelDesc>>() const {
return **JUST(detail::cast<std::shared_ptr<Symbol<ParallelDesc>>>(Borrow()));
60 changes: 60 additions & 0 deletions oneflow/api/python/rpc/consistent_rpc_token_scope.cpp
@@ -0,0 +1,60 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/functional.h>
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/thread/consistent_unique_id.h"
#include "oneflow/core/framework/rank_group_rpc_util.h"
#include "oneflow/core/job/rank_group.h"
#include "oneflow/core/job/rank_group_scope.h"
#include "oneflow/core/common/symbol.h"

namespace py = pybind11;

namespace oneflow {

namespace {

Maybe<void> InitConsistentRpcTokenScope(const std::string& thread_tag,
int64_t thread_consistent_uid,
Symbol<RankGroup> rank_group) {
JUST(SetThisThreadConsistentUniqueId(thread_consistent_uid, thread_tag));
static thread_local const auto& init_rank_group_scope =
JUST(RankGroupScope::MakeInitialRankGroupScope(rank_group));
// no unused warning for `init_rank_group_scope`.
(void)(init_rank_group_scope);
return Maybe<void>::Ok();
}

Maybe<void> InitConsistentRpcTokenScope(const std::string& thread_tag,
int64_t thread_consistent_uid) {
const auto& rank_group = JUST(RankGroup::DefaultRankGroup());
JUST(InitConsistentRpcTokenScope(thread_tag, thread_consistent_uid, rank_group));
return Maybe<void>::Ok();
}

void ApiInitDefaultConsistentRpcTokenScope() {
return InitConsistentRpcTokenScope("main", 0).GetOrThrow();
}

} // namespace

ONEFLOW_API_PYBIND11_MODULE("", m) {
m.def("InitDefaultConsistentRpcTokenScope", &ApiInitDefaultConsistentRpcTokenScope);
}

} // namespace oneflow
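The `static thread_local` initializer in `InitConsistentRpcTokenScope` above is a once-per-thread idiom: the first call on each thread runs the initializer (`MakeInitialRankGroupScope`), later calls on the same thread skip it, and the `(void)` cast silences the unused-variable warning, as the comment notes. A self-contained sketch of the idiom with a hypothetical counter:

```cpp
#include <atomic>
#include <cassert>
#include <thread>

// Counts how many times per-thread initialization actually runs.
std::atomic<int> init_count{0};

int InitOnce() {
  ++init_count;
  return 42;
}

// Mirrors the PR's trick: a function-local `static thread_local` runs its
// initializer once per thread, no matter how often the function is called.
void EnterScope() {
  static thread_local const int token = InitOnce();
  (void)token;  // suppress the unused-variable warning, as in the PR
}
```

Calling `EnterScope` repeatedly on one thread triggers `InitOnce` exactly once; a second thread triggers it once more.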
47 changes: 47 additions & 0 deletions oneflow/api/python/rpc/rank_group.cpp
@@ -0,0 +1,47 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/functional.h>
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/framework/rank_group_rpc_util.h"
#include "oneflow/core/job/rank_group.h"
#include "oneflow/core/job/rank_group_scope.h"
#include "oneflow/core/common/symbol.h"

namespace py = pybind11;

namespace oneflow {

namespace {

Maybe<void> CheckCurrentRankGroupConsistency(int64_t seconds) {
const auto& rank_group = JUST(RankGroupScope::CurrentRankGroup());
const auto& ctx = JUST(CheckRpcToken(rank_group));
JUST(RpcUtil::WaitUntilDoneOrTimeout(*ctx, seconds));
return Maybe<void>::Ok();
}

} // namespace

ONEFLOW_API_PYBIND11_MODULE("", m) {
m.def("check_current_rank_group_consistency",
[](int64_t seconds) { return CheckCurrentRankGroupConsistency(seconds).GetOrThrow(); });
m.def("check_current_rank_group_consistency",
[]() { return CheckCurrentRankGroupConsistency(60 * 5).GetOrThrow(); });
}

} // namespace oneflow
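`WaitUntilDoneOrTimeout` (called here with a caller-supplied number of seconds, or a 5-minute default via the second overload) is a timed rendezvous: block until the check completes or the deadline passes. A hypothetical sketch of such a helper built on `std::condition_variable::wait_for` — an illustration of the pattern, not oneflow's actual `RpcUtil` implementation:

```cpp
#include <cassert>
#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <mutex>
#include <thread>

// Hypothetical WaitUntilDoneOrTimeout-style helper: block until a flag is
// set or the deadline passes, and report which of the two happened.
class Notifier {
 public:
  void Notify() {
    {
      std::lock_guard<std::mutex> lk(mu_);
      done_ = true;
    }
    cv_.notify_all();
  }

  // Returns true if done before the timeout, false on timeout.
  bool WaitUntilDoneOrTimeout(int64_t seconds) {
    std::unique_lock<std::mutex> lk(mu_);
    return cv_.wait_for(lk, std::chrono::seconds(seconds),
                        [this] { return done_; });
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  bool done_ = false;
};
```

The predicate form of `wait_for` handles spurious wakeups, and a timeout would map naturally onto the `TimeoutError` added to `error.proto` later in this PR.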
5 changes: 4 additions & 1 deletion oneflow/api/python/symbol/placement_symbol.cpp
@@ -193,6 +193,7 @@ struct PlacementSymbolExportUtil {
return placement_str;
}
};

} // namespace

ONEFLOW_API_PYBIND11_MODULE("", m) {
@@ -239,7 +240,9 @@ ONEFLOW_API_PYBIND11_MODULE("", m) {
})
.def_property_readonly("hierarchy", [](Symbol<ParallelDesc> p) { return p->hierarchy(); })
.def("__str__", &PlacementSymbolExportUtil::PlacementSymbol2String)
.def("__repr__", &PlacementSymbolExportUtil::PlacementSymbol2String);
.def("__repr__", &PlacementSymbolExportUtil::PlacementSymbol2String)
.def(py::self == py::self)
.def(py::hash(py::self));
m.def("AllDevicePlacement", &PlacementSymbolExportUtil::AllDevicePlacement);
}

5 changes: 4 additions & 1 deletion oneflow/api/python/symbol/sbp_symbol.cpp
@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>
#include <pybind11/operators.h>
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/common/maybe.h"
@@ -88,7 +89,9 @@ ONEFLOW_API_PYBIND11_MODULE("sbp", m) {
m.attr("max_split_axis") = kMaxSplitAxis;
py::class_<Symbol<cfg::SbpParallel>, std::shared_ptr<Symbol<cfg::SbpParallel>>>(m, "sbp")
.def("__str__", &SbpParallelSymbolToString)
.def("__repr__", &SbpParallelSymbolToString);
.def("__repr__", &SbpParallelSymbolToString)
.def(py::self == py::self)
.def(py::hash(py::self));
m.def(
"split", [](int axis) { return GetSplitSbpParallel(axis).GetOrThrow(); }, py::arg("axis"));
m.def("broadcast", []() { return GetBroadcastSbpParallel().GetOrThrow(); });
Expand Down
37 changes: 34 additions & 3 deletions oneflow/core/autograd/autograd_meta.cpp
@@ -23,9 +23,40 @@ namespace oneflow {

namespace one {

TensorInfo::TensorInfo(const Tensor& tensor) : shape_(tensor.shape()), dtype_(tensor.dtype()) {}

Maybe<Tensor> TensorInfo::zeros() const { return functional::Constant(*shape_.get(), 0, dtype_); }
TensorInfo::TensorInfo(const Tensor& tensor) : shape_(tensor.shape()), dtype_(tensor.dtype()) {
if (TRY(tensor.device()).IsOk()) { device_ = CHECK_JUST(tensor.device()); }
if (TRY(tensor.parallel_desc()).IsOk()) { parallel_desc_ = CHECK_JUST(tensor.parallel_desc()); }
if (TRY(tensor.parallel_distribution()).IsOk()) {
parallel_distribution_ = CHECK_JUST(tensor.parallel_distribution());
}
}

Maybe<const std::vector<Symbol<cfg::SbpParallel>>&> GetSbpTuple(
Symbol<cfg::ParallelDistribution> parallel_distribution) {
static thread_local HashMap<Symbol<cfg::ParallelDistribution>, std::vector<Symbol<cfg::SbpParallel>>> map;
auto iter = map.find(parallel_distribution);
if (iter == map.end()) {
std::vector<Symbol<cfg::SbpParallel>> sbp_tuple;
for (const auto& sbp_parallel : parallel_distribution->sbp_parallel()) {
sbp_tuple.push_back(SymbolOf(sbp_parallel));
}
iter = map.emplace(parallel_distribution, sbp_tuple).first;
}
return iter->second;
}

Maybe<Tensor> TensorInfo::zeros() const {
if (device_.has_value()) {
const auto& device = JUST(device_.value());
return functional::Constant(*shape_.get(), 0, dtype_, device);
} else {
const auto& parallel_desc = JUST(parallel_desc_.value());
const auto& parallel_distribution = JUST(parallel_distribution_.value());
const auto& sbp_tuple = JUST(GetSbpTuple(parallel_distribution));
return functional::ConsistentConstant(
*shape_.get(), 0, dtype_, parallel_desc, sbp_tuple);
}
}

} // namespace one
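`GetSbpTuple` above memoizes the converted tuple in a `thread_local` map, so the vector is built once per distribution per thread and a stable reference can be handed out without synchronization. The same pattern on a toy key type (all names hypothetical):

```cpp
#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

// Tracks how many times the cached value is actually constructed.
int build_count = 0;

// thread_local memoization, as in GetSbpTuple: look up the key, build and
// insert the derived vector only on a miss, and return a reference into
// the map so repeated calls are cheap and address-stable.
const std::vector<std::string>& SplitCached(const std::string& key) {
  static thread_local std::unordered_map<std::string, std::vector<std::string>>
      map;
  auto iter = map.find(key);
  if (iter == map.end()) {
    ++build_count;
    std::vector<std::string> parts;
    for (char c : key) { parts.push_back(std::string(1, c)); }
    iter = map.emplace(key, std::move(parts)).first;
  }
  return iter->second;
}
```

`thread_local` sidesteps locking at the cost of one cache per thread; references into an `unordered_map` stay valid across later insertions, which is what makes returning `iter->second` by reference safe.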

12 changes: 11 additions & 1 deletion oneflow/core/autograd/autograd_meta.h
@@ -21,11 +21,19 @@ limitations under the License.
#include "oneflow/core/common/data_type.pb.h"
#include "oneflow/core/framework/tensor_arg.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/common/symbol.h"
#include "oneflow/core/common/optional.h"

namespace oneflow {

class Shape;

class Device;
class ParallelDesc;
namespace cfg {
class ParallelDistribution;
}

namespace one {

class Tensor;
@@ -86,7 +94,9 @@ class TensorInfo final {
private:
std::shared_ptr<const Shape> shape_;
DataType dtype_;
// TODO: Add device info
Optional<Symbol<Device>> device_; // for local tensor
Optional<Symbol<ParallelDesc>> parallel_desc_; // for consistent tensor
Optional<Symbol<cfg::ParallelDistribution>> parallel_distribution_; // for consistent tensor
};

} // namespace one
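The new `Optional` members give `TensorInfo` two mutually exclusive flavors: a local tensor records a device, a consistent tensor records placement and distribution, and `zeros()` branches on which is populated. A sketch of that dispatch using `std::optional` in place of oneflow's `Optional`, with hypothetical names and string stand-ins for the symbol types:

```cpp
#include <cassert>
#include <optional>
#include <string>

// Hypothetical mirror of the TensorInfo split: exactly one of the two
// optionals is expected to be set, depending on the tensor kind.
struct Info {
  std::optional<std::string> device;         // for local tensors
  std::optional<std::string> parallel_desc;  // for consistent tensors
};

// Branch on the populated member, as TensorInfo::zeros() does when it
// chooses between Constant and ConsistentConstant.
std::string DescribeZeros(const Info& info) {
  if (info.device.has_value()) {
    return "local zeros on " + *info.device;
  }
  return "consistent zeros on " + info.parallel_desc.value();
}
```

Keeping both optionals in one struct (rather than two subclasses) matches the diff: construction probes the tensor with `TRY(...)` and records whichever metadata is available.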
6 changes: 6 additions & 0 deletions oneflow/core/common/error.cpp
@@ -76,6 +76,12 @@ Error Error::IndexError() {
return error;
}

Error Error::TimeoutError() {
auto error = std::make_shared<cfg::ErrorProto>();
error->mutable_timeout_error();
return error;
}

Error Error::JobNameExistError() {
auto error = std::make_shared<cfg::ErrorProto>();
error->mutable_job_name_exist_error();
1 change: 1 addition & 0 deletions oneflow/core/common/error.h
@@ -44,6 +44,7 @@ class Error final {
static Error DeviceTagNotFoundError();
static Error ValueError(const std::string& error_summary);
static Error IndexError();
static Error TimeoutError();
static Error JobNameExistError();
static Error JobNameEmptyError();
static Error JobNameNotEqualError();
3 changes: 3 additions & 0 deletions oneflow/core/common/error.proto
@@ -127,6 +127,8 @@ message ValueError {}

message IndexError {}

message TimeoutError {}

message ErrorProto {
optional string error_summary = 1 [default = ""];
optional string msg = 2 [default = ""];
@@ -148,6 +150,7 @@ message ErrorProto {
DeviceTagNotFoundError device_tag_not_found_error = 26;
ValueError value_error = 27;
IndexError index_error = 28;
TimeoutError timeout_error = 29;
JobNameExistError job_name_exist_error = 100;
JobNameEmptyError job_name_empty_error = 101;
JobNameNotEqualError job_name_not_equal_error = 102;
1 change: 1 addition & 0 deletions oneflow/core/common/exception.h
@@ -70,6 +70,7 @@ class Exception : public std::exception {
OF_PP_MAKE_TUPLE_SEQ(CompileOptionWrong) \
OF_PP_MAKE_TUPLE_SEQ(Value) \
OF_PP_MAKE_TUPLE_SEQ(Index) \
OF_PP_MAKE_TUPLE_SEQ(Timeout) \
OF_PP_MAKE_TUPLE_SEQ(InputDeviceNotMatch)

#define DEFINE_EXCEPTION_CLASS(cls) \