-
Notifications
You must be signed in to change notification settings - Fork 656
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support tensor.to()/to_local() (#5271)
* support_tensor_to/to_local * export consistent_tensor.to_local() * refine code * export tensor.to()... * refine code * refine code * optimize code * refine code * refine * back up * add tensor.to func * make of_format * remove to in pyTensor * sync gpu data * refine * refine * refine * refine * refine * refine * refine * refine * refine * backup * refine * rebase * check in gen py * merge master and fix bugs * address pr comments * address pr comments * auto format by CI * remove boxing * refine * Fix optional * remove to in tensor.cpp * update * Support symbol placement type in functional. * add sbp and sbp list arg * refine * use functional * refactor CastConsistentOpExpr * to_consistent(flow.B) backward * Cache op expr * add EagerNcclOpKernelState * refine * refine * refine * refine * refine * refine * minor fix * capture OpInterpContext * unimplemented apply * add GetNdSbp * add mutex * refine * merge EagerConsistentTensorImpl::NewWithPhyTensor and EagerConsistentTensorImpl::NewWithoutPhyTensor into EagerConsistentTensorImpl::New * rename functiona SyncData to SyncMetaAndData * of_format * add to_local to pybind * add placement_sbp_util * minor fix * sync shape and data when tensor_to_local * fix rpc_token bugs * refactor AsyncRpcCtx * set logical_shape correctly * simplify implementation of consistent_tensor.to_local * initialize rpc_token with zero * refactor grad functions of to_consistent/to_local * reformat and address pr comment * reformat * refactor eager_nccl_reduce lernel Co-authored-by: tsai <jackalcooper@gmail.com> Co-authored-by: Xinqi Li <lixinqi0703106@163.com> Co-authored-by: Li Xinqi <lixinqi2010@gmail.com> Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org> Co-authored-by: hjchen2 <chenhoujiangcug@gmail.com> Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
- Loading branch information
1 parent
c8b6d39
commit a72c21d
Showing
45 changed files
with
1,573 additions
and
237 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
104 changes: 104 additions & 0 deletions
104
oneflow/core/autograd/gradient_funcs/consistent_cast.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
/* | ||
Copyright 2020 The OneFlow Authors. All rights reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
*/ | ||
#include "oneflow/core/framework/op_expr_grad_function.h" | ||
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" | ||
#include "oneflow/core/framework/op_expr.h" | ||
#include "oneflow/core/framework/op_expr_helper.h" | ||
#include "oneflow/core/framework/nd_sbp.h" | ||
|
||
namespace oneflow { | ||
namespace one { | ||
|
||
// Forward-pass state captured for the cast_to_consistent / cast_from_consistent
// gradient functions below. Both gradient classes share this state type.
struct CastConsistentOpExprInterpState : public OpExprInterpState {
  Symbol<ParallelDesc> parallel_desc;                       // placement of the forward cast
  Symbol<cfg::ParallelDistribution> parallel_distribution;  // nd-sbp of the forward cast
  std::shared_ptr<const Shape> shape;  // logical shape; captured by CastFromConsistent only
};
|
||
class CastToConsistent : public OpExprGradFunction<CastConsistentOpExprInterpState> { | ||
public: | ||
Maybe<void> Init(const OpExpr& op) override { | ||
const auto* fw_op_expr = dynamic_cast<const CastToConsistentOpExpr*>(&op); | ||
CHECK_NOTNULL_OR_RETURN(fw_op_expr); | ||
const std::string& op_name = fw_op_expr->op_name(); | ||
grad_op_ = JUST(one::CastFromConsistentOpExpr::New(GradientOpName(op_name))); | ||
return Maybe<void>::Ok(); | ||
} | ||
|
||
Maybe<void> Capture(CastConsistentOpExprInterpState* ctx, const TensorTuple& inputs, | ||
const TensorTuple& outputs, | ||
const OpExprInterpContext& interp_ctx) const override { | ||
ctx->parallel_desc = JUST(interp_ctx.parallel_desc.value()); | ||
ctx->parallel_distribution = JUST(interp_ctx.parallel_distribution.value()); | ||
return Maybe<void>::Ok(); | ||
} | ||
|
||
Maybe<void> Apply(const CastConsistentOpExprInterpState* ctx, const TensorTuple& out_grads, | ||
TensorTuple* in_grads) const override { | ||
const auto& out_grad = out_grads.at(0); | ||
CHECK_OR_RETURN(out_grad->is_consistent()); | ||
const auto& bw_parallel_distribution = JUST(out_grad->parallel_distribution()); | ||
const auto& dual_parallel_distribution = JUST(GetDualNdSbp(ctx->parallel_distribution)); | ||
CHECK_OR_RETURN(bw_parallel_distribution == dual_parallel_distribution); | ||
in_grads->at(0) = JUST(OpInterpUtil::Dispatch<Tensor>(*grad_op_, {out_grads.at(0)})); | ||
return Maybe<void>::Ok(); | ||
} | ||
|
||
private: | ||
std::shared_ptr<OpExpr> grad_op_; | ||
}; | ||
|
||
REGISTER_OP_EXPR_GRAD_FUNCTION("cast_to_consistent", CastToConsistent); | ||
|
||
class CastFromConsistent : public OpExprGradFunction<CastConsistentOpExprInterpState> { | ||
public: | ||
Maybe<void> Init(const OpExpr& op) override { | ||
const auto* fw_op_expr = dynamic_cast<const CastFromConsistentOpExpr*>(&op); | ||
CHECK_NOTNULL_OR_RETURN(fw_op_expr); | ||
const std::string& op_name = fw_op_expr->op_name(); | ||
grad_op_ = JUST(one::CastToConsistentOpExpr::New(GradientOpName(op_name))); | ||
return Maybe<void>::Ok(); | ||
} | ||
|
||
Maybe<void> Capture(CastConsistentOpExprInterpState* ctx, const TensorTuple& inputs, | ||
const TensorTuple& outputs, const AttrMap& attrs) const override { | ||
const auto& input = inputs.at(0); | ||
CHECK_OR_RETURN(input->is_consistent()); | ||
ctx->parallel_desc = JUST(input->parallel_desc()); | ||
ctx->parallel_distribution = JUST(input->parallel_distribution()); | ||
ctx->shape = input->shape(); | ||
return Maybe<void>::Ok(); | ||
} | ||
|
||
Maybe<void> Apply(const CastConsistentOpExprInterpState* ctx, const TensorTuple& out_grads, | ||
TensorTuple* in_grads) const override { | ||
const auto& dual_parallel_distribution = JUST(GetDualNdSbp(ctx->parallel_distribution)); | ||
MutableAttrMap attrs; | ||
JUST(attrs.SetAttr<Shape>("shape", *ctx->shape)); | ||
in_grads->at(0) = JUST(OpInterpUtil::Dispatch<Tensor>( | ||
*grad_op_, {out_grads.at(0)}, | ||
OpExprInterpContext(attrs, ctx->parallel_desc, dual_parallel_distribution))); | ||
return Maybe<void>::Ok(); | ||
} | ||
|
||
private: | ||
std::shared_ptr<OpExpr> grad_op_; | ||
}; | ||
|
||
REGISTER_OP_EXPR_GRAD_FUNCTION("cast_from_consistent", CastFromConsistent); | ||
|
||
} // namespace one | ||
} // namespace oneflow |
88 changes: 88 additions & 0 deletions
88
oneflow/core/autograd/gradient_funcs/eager_nccl_broadcast.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
/* | ||
Copyright 2020 The OneFlow Authors. All rights reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
*/ | ||
#include "oneflow/core/framework/id_util.h" | ||
#include "oneflow/core/framework/op_builder.h" | ||
#include "oneflow/core/framework/op_expr_grad_function.h" | ||
#include "oneflow/core/framework/device.h" | ||
#include "oneflow/core/framework/op_builder.h" | ||
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" | ||
#include "oneflow/core/framework/op_expr.h" | ||
#include "oneflow/core/framework/op_expr_helper.h" | ||
|
||
namespace oneflow { | ||
|
||
namespace one { | ||
|
||
namespace { | ||
|
||
// Builds an eager_nccl_reduce user op that reduces its input to rank `root`
// within the group described by `parallel_desc`.
Maybe<one::UserOpExpr> EagerNcclReduce(Symbol<ParallelDesc> parallel_desc, int64_t root) {
  const std::string parallel_conf_txt = PbMessage2TxtString(parallel_desc->parallel_conf());
  one::OpBuilder builder("eager_nccl_reduce", *JUST(UniqueStr("eager_nccl_reduce")));
  return builder.Input("in")
      .Output("out")
      .Attr<std::string>("parallel_conf", parallel_conf_txt)
      .Attr<int64_t>("root", root)
      .Build();
}
|
||
// Returns a memoized eager_nccl_reduce op expr for (parallel_desc, root),
// building it on first use. The cache is thread_local, so no locking is needed.
Maybe<one::UserOpExpr> FindOrCreatEagerNcclReduceOpExpr(Symbol<ParallelDesc> parallel_desc,
                                                        int64_t root) {
  thread_local HashMap<std::pair<Symbol<ParallelDesc>, int64_t>, std::shared_ptr<one::UserOpExpr>>
      reduce_op_cache;
  const std::pair<Symbol<ParallelDesc>, int64_t> cache_key(parallel_desc, root);
  auto it = reduce_op_cache.find(cache_key);
  if (it == reduce_op_cache.end()) {
    // Build first and insert only after JUST succeeds, so a failed build can
    // never leave a null entry in the cache.
    std::shared_ptr<one::UserOpExpr> reduce_op = JUST(EagerNcclReduce(parallel_desc, root));
    it = reduce_op_cache.emplace(cache_key, std::move(reduce_op)).first;
  }
  return it->second;
}
|
||
} // namespace | ||
|
||
// Forward-pass state captured for the eager_nccl_broadcast gradient function.
struct EagerNcclBroadcastOpExprInterpState : public OpExprInterpState {
  Symbol<ParallelDesc> parallel_desc;  // group the broadcast ran over
  int64_t root;                        // rank the forward op broadcast from
};
|
||
class EagerNcclBroadcast : public OpExprGradFunction<EagerNcclBroadcastOpExprInterpState> { | ||
public: | ||
Maybe<void> Init(const OpExpr& op) override { | ||
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op); | ||
CHECK_NOTNULL_OR_RETURN(fw_op_expr); | ||
return Maybe<void>::Ok(); | ||
} | ||
|
||
Maybe<void> Capture(EagerNcclBroadcastOpExprInterpState* ctx, const TensorTuple& inputs, | ||
const TensorTuple& outputs, | ||
const OpExprInterpContext& interp_ctx) const override { | ||
ctx->root = JUST(interp_ctx.attrs.GetAttr<int64_t>("root")); | ||
ctx->parallel_desc = JUST(interp_ctx.parallel_desc.value()); | ||
return Maybe<void>::Ok(); | ||
} | ||
|
||
Maybe<void> Apply(const EagerNcclBroadcastOpExprInterpState* ctx, const TensorTuple& out_grads, | ||
TensorTuple* in_grads) const override { | ||
const auto& grad_op = JUST(FindOrCreatEagerNcclReduceOpExpr(ctx->parallel_desc, ctx->root)); | ||
in_grads->resize(1); | ||
in_grads->at(0) = JUST(OpInterpUtil::Dispatch<Tensor>(*grad_op, {out_grads.at(0)})); | ||
return Maybe<void>::Ok(); | ||
} | ||
}; | ||
|
||
REGISTER_OP_EXPR_GRAD_FUNCTION("eager_nccl_broadcast", EagerNcclBroadcast); | ||
|
||
} // namespace one | ||
} // namespace oneflow |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
/* | ||
Copyright 2020 The OneFlow Authors. All rights reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
*/ | ||
#include <map>
#include <utility>
#include <vector>

#include "oneflow/core/common/optional.h"
#include "oneflow/core/common/util.h"
|
||
namespace oneflow { | ||
namespace test { | ||
|
||
// Pushing an lvalue Optional into a vector must copy-construct it and preserve
// the held value.
TEST(Optional, copy_constructor) {
  Optional<int64_t> src(0);
  std::vector<Optional<int64_t>> holder;
  holder.push_back(src);  // lvalue argument -> copy constructor
  ASSERT_TRUE(holder[0].has_value());
  int64_t held = CHECK_JUST(holder[0].value());
  ASSERT_EQ(held, 0);
}
|
||
// Emplacing a moved-from Optional into a map must move-construct the stored
// element and preserve the held value.
TEST(Optional, move_constructor) {
  Optional<int64_t> a(0);
  std::map<int64_t, Optional<int64_t>> map;
  // std::move is required here: passing the lvalue `a` would invoke the copy
  // constructor instead, so this test would not exercise the move constructor
  // its name claims to cover.
  map.emplace(0, std::move(a));
  ASSERT_TRUE(map.at(0).has_value());
  int64_t val = CHECK_JUST(map.at(0).value());
  ASSERT_EQ(val, 0);
}
|
||
} // namespace test | ||
} // namespace oneflow |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.