Skip to content

Commit

Permalink
bugfix: DeviceId4ParallelId -> MachineId4ParallelId
Browse files Browse the repository at this point in the history
  • Loading branch information
lixinqi committed Aug 5, 2021
1 parent ab450a5 commit ad5c97c
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License.
*/
#include <utility>
#include "oneflow/core/common/constant.h"
#include "oneflow/core/common/cached_caller.h"
#include "oneflow/core/common/container_util.h"
#include "oneflow/core/framework/op_interpreter/boxing/eager_boxing_interpreter_mgr.h"
#include "oneflow/core/framework/op_interpreter/boxing/eager_boxing_interpreter_util.h"
Expand Down Expand Up @@ -68,12 +69,10 @@ Maybe<EagerBoxingInterpreter> GetOneDimNcclCollectiveEagerBoxingInterpreter(
out_parallel_distribution->sbp_parallel(0))));
}

} // namespace

Maybe<EagerBoxingInterpreter> EagerBoxingInterpreterManager::GetEagerBoxingInterpreter(
Maybe<EagerBoxingInterpreter> GetBoxingInterpreter(
Symbol<cfg::ParallelDistribution> in_parallel_distribution,
Symbol<cfg::ParallelDistribution> out_parallel_distribution,
Symbol<ParallelDesc> in_parallel_desc, Symbol<ParallelDesc> out_parallel_desc) const {
Symbol<ParallelDesc> in_parallel_desc, Symbol<ParallelDesc> out_parallel_desc) {
if (in_parallel_distribution == out_parallel_distribution
&& in_parallel_desc == out_parallel_desc) {
static std::shared_ptr<EagerBoxingInterpreter> identity_boxing_interpreter =
Expand All @@ -82,14 +81,13 @@ Maybe<EagerBoxingInterpreter> EagerBoxingInterpreterManager::GetEagerBoxingInter
}
if (in_parallel_distribution->sbp_parallel_size() == 1
&& out_parallel_distribution->sbp_parallel_size() == 1) {
if (EagerBoxingInterpreterUtil::IsPlacementEqual(in_parallel_desc, out_parallel_desc)) {
if (in_parallel_desc == out_parallel_desc) {
if (EagerBoxingInterpreterUtil::IsBoxingB2P(in_parallel_distribution->sbp_parallel(0),
out_parallel_distribution->sbp_parallel(0))) {
std::shared_ptr<EagerBoxingInterpreter> naive_bp_boxing_interpreter =
std::make_shared<NaiveB2PBoxingInterpreter>();
return naive_bp_boxing_interpreter;
}
if (EagerBoxingInterpreterUtil::IsDeviceTypeGPU(in_parallel_desc)) {
} else if (in_parallel_desc->device_type() == DeviceType::kGPU) {
return GetOneDimNcclCollectiveEagerBoxingInterpreter(in_parallel_distribution,
out_parallel_distribution);
} else {
Expand All @@ -103,6 +101,18 @@ Maybe<EagerBoxingInterpreter> EagerBoxingInterpreterManager::GetEagerBoxingInter
}
}

// Memoized wrapper around GetBoxingInterpreter: repeated calls with the same
// (in/out distribution, in/out placement) arguments return the cached interpreter.
// NOTE(review): THREAD_LOCAL_CACHED comes from common/cached_caller.h — presumably
// a per-thread cache keyed on the argument tuple; confirm against that header.
auto* CachedGetBoxingInterpreter = THREAD_LOCAL_CACHED(&GetBoxingInterpreter);

} // namespace

// Public entry point for selecting an eager boxing interpreter.
// Thin forwarder to the file-local cached callable (CachedGetBoxingInterpreter),
// so interpreter selection for identical argument tuples is not recomputed.
// Returns Error via Maybe<> when no interpreter supports the requested boxing.
Maybe<EagerBoxingInterpreter> EagerBoxingInterpreterManager::GetEagerBoxingInterpreter(
    Symbol<cfg::ParallelDistribution> in_parallel_distribution,
    Symbol<cfg::ParallelDistribution> out_parallel_distribution,
    Symbol<ParallelDesc> in_parallel_desc, Symbol<ParallelDesc> out_parallel_desc) const {
  return CachedGetBoxingInterpreter(in_parallel_distribution, out_parallel_distribution,
                                    in_parallel_desc, out_parallel_desc);
}

COMMAND(Global<EagerBoxingInterpreterManager>::SetAllocated(new EagerBoxingInterpreterManager()));

} // namespace oneflow
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,6 @@ limitations under the License.

namespace oneflow {

// Two placements are equal exactly when their interned ParallelDesc symbols
// compare equal; no structural comparison is needed.
bool EagerBoxingInterpreterUtil::IsPlacementEqual(Symbol<ParallelDesc> src,
                                                  Symbol<ParallelDesc> dst) {
  const bool same_placement = (src == dst);
  return same_placement;
}

// Reports whether the given placement's device type is GPU (DeviceType::kGPU).
bool EagerBoxingInterpreterUtil::IsDeviceTypeGPU(Symbol<ParallelDesc> parallel_desc) {
  const auto device_type = parallel_desc->device_type();
  return DeviceType::kGPU == device_type;
}

bool EagerBoxingInterpreterUtil::IsBoxingS2B(const cfg::SbpParallel& src,
const cfg::SbpParallel& dst) {
return src.has_split_parallel() && dst.has_broadcast_parallel();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ limitations under the License.
namespace oneflow {

struct EagerBoxingInterpreterUtil {
static bool IsPlacementEqual(Symbol<ParallelDesc> src, Symbol<ParallelDesc> dst);
static bool IsDeviceTypeGPU(Symbol<ParallelDesc> parallel_desc);
static bool IsBoxingS2B(const cfg::SbpParallel& src, const cfg::SbpParallel& dst);
static bool IsBoxingS2P(const cfg::SbpParallel& src, const cfg::SbpParallel& dst);
static bool IsBoxingP2S(const cfg::SbpParallel& src, const cfg::SbpParallel& dst);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ Maybe<one::Tensor> NaiveB2PBoxingInterpreter::InterpretImpl(
Symbol<cfg::ParallelDistribution> out_parallel_distribution,
Symbol<ParallelDesc> in_parallel_desc, Symbol<ParallelDesc> out_parallel_desc) const {
CHECK_EQ_OR_RETURN(in_parallel_desc, out_parallel_desc);
int64_t root = JUST(in_parallel_desc->DeviceId4ParallelId(0));
int64_t root = JUST(in_parallel_desc->MachineId4ParallelId(0));
if (root == GlobalProcessCtx::LocalRank()) {
std::string device_type = Device::DeviceType4ParallelDesc(in_parallel_desc->device_tag());
return JUST(one::functional::Copy(input, device_type, root));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ limitations under the License.
#include "oneflow/core/autograd/autograd_mode.h"
#include "oneflow/core/framework/op_interpreter/boxing/eager_boxing_interpreter_mgr.h"
#include "oneflow/user/kernels/stateful_local_opkernel.h"
#include "oneflow/core/framework/tensor_rpc_util.h"

namespace oneflow {
namespace one {
Expand Down Expand Up @@ -61,6 +62,26 @@ std::string GetDynamicOpConsistentFailedDebugString(const UserOpExpr& user_op_ex
return ss.str();
}

namespace {

// Boxes `input` to `parallel_distribution` while keeping its placement
// (parallel_desc) unchanged on both the input and output sides.
// The tensor-meta consistency check is launched *before* boxing and awaited
// *after* it, so the cross-rank check overlaps with the boxing work;
// do not reorder these statements.
Maybe<Tensor> GetBoxingOutput(const std::shared_ptr<Tensor>& input,
                              Symbol<cfg::ParallelDistribution> parallel_distribution) {
  // Kick off the asynchronous cross-rank tensor-meta consistency check.
  const auto& ctx = JUST(LaunchTensorMetaConsistencyCheck(*input));
  // Eager boxing
  const auto& boxing_interpreter =
      JUST(Global<EagerBoxingInterpreterManager>::Get()->GetEagerBoxingInterpreter(
          JUST(input->parallel_distribution()), parallel_distribution, JUST(input->parallel_desc()),
          JUST(input->parallel_desc())));
  const auto& output = JUST(boxing_interpreter->Interpret(
      input, JUST(input->parallel_distribution()), parallel_distribution,
      JUST(input->parallel_desc()), JUST(input->parallel_desc())));
  // Block until every rank has reported, then surface any meta mismatch.
  JUST(RpcUtil::WaitUntilDoneOrTimeout(*ctx, RpcUtil::TimeoutSeconds()));
  JUST(ctx->Check());
  return output;
}

}  // namespace

} // namespace

Maybe<void> Interpret(const UserOpExpr& user_op_expr, const TensorTuple& inputs,
Expand All @@ -87,25 +108,19 @@ Maybe<void> Interpret(const UserOpExpr& user_op_expr, const TensorTuple& inputs,
outputs->at(i).reset(new ConsistentTensor(tensor_impl));
}
// Do nothing if the `parallel_desc` doesn't cover current ProcessCtx.
if (!device) { return Maybe<void>::Ok(); }
if (!parallel_id.has_value()) { return Maybe<void>::Ok(); }
// Run instruction LocalCallOpKernel
const auto& kernel = JUST(user_op_expr.MutKernel4Device(*device));
CHECK_EQ_OR_RETURN(kernel->output_tuple_indexes4mut2_obns().size(), 0)
<< Error::Unimplemented() << GetDynamicOpConsistentFailedDebugString(user_op_expr, *kernel);
std::shared_ptr<EagerBlobObjectList> input_eager_blob_objects =
std::make_shared<EagerBlobObjectList>(inputs.size());
for (int i = 0; i < inputs.size(); ++i) {
// Eager boxing
const auto& boxing_interpreter =
JUST(Global<EagerBoxingInterpreterManager>::Get()->GetEagerBoxingInterpreter(
JUST(inputs.at(i)->parallel_distribution()),
result->input_parallel_distributions().at(i), JUST(inputs.at(i)->parallel_desc()),
JUST(inputs.at(i)->parallel_desc())));
const auto& boxing_output = JUST(boxing_interpreter->Interpret(
inputs.at(i), JUST(inputs.at(i)->parallel_distribution()),
result->input_parallel_distributions().at(i), JUST(inputs.at(i)->parallel_desc()),
JUST(inputs.at(i)->parallel_desc())));
const auto& local_tensor = JUST(boxing_output->cur_rank_phy_tensor());
std::shared_ptr<Tensor> input = inputs.at(i);
if (result->input_parallel_distributions().at(i) != JUST(input->parallel_distribution())) {
input = JUST(GetBoxingOutput(input, result->input_parallel_distributions().at(i)));
}
const auto& local_tensor = JUST(input->cur_rank_phy_tensor());
input_eager_blob_objects->at(i) = JUST(local_tensor->eager_blob_object());
}
std::shared_ptr<EagerBlobObjectList> output_eager_blob_objects =
Expand Down
12 changes: 4 additions & 8 deletions oneflow/core/functional/impl/consistent_cast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,10 @@ limitations under the License.
#include "oneflow/core/job/global_for.h"
#include "oneflow/core/job/resource_desc.h"
#include "oneflow/core/job/rank_group_scope.h"
#include "oneflow/core/framework/rpc_token.h"
#include "oneflow/core/framework/rpc_util.h"
#include "oneflow/core/common/flat_shape.h"
#include "oneflow/core/common/container_util.h"
#include "oneflow/core/common/balanced_splitter.h"
#include "oneflow/core/framework/tensor_rpc_util.h"
#include "oneflow/core/framework/rpc_util.h"

namespace oneflow {
namespace one {
Expand Down Expand Up @@ -145,16 +143,14 @@ Maybe<one::UserOpExpr> FindOrCreatParallelDistributionOpExpr(
Maybe<Tensor> ConsistentToConsistent(const std::shared_ptr<Tensor>& x,
Symbol<ParallelDesc> parallel_desc,
const std::vector<Symbol<cfg::SbpParallel>>& sbp_parallels) {
const auto& ctx = JUST(LaunchTensorMetaConsistencyCheck(*x));
JUST(RpcUtil::WaitUntilDoneOrTimeout(*ctx, RpcUtil::TimeoutSeconds()));
JUST(ctx->Check());

This comment has been minimized.

Copy link
@clackhan

clackhan Aug 5, 2021

Contributor

这里不需要同步吗?

This comment has been minimized.

Copy link
@lixinqi

lixinqi Aug 5, 2021

Author Contributor

移到了eager_consistent_op_interpreter里去了。那才是对的位置。

const auto& consistent_tensor = std::dynamic_pointer_cast<ConsistentTensor>(x);
CHECK_NOTNULL_OR_RETURN(consistent_tensor) << "consistent tensors supported only";
CHECK_OR_RETURN(consistent_tensor->is_eager()) << "eager tensors supported only";
const auto& parallel_distribution_cast_op_expr =
JUST(FindOrCreatParallelDistributionOpExpr(sbp_parallels));
return JUST(OpInterpUtil::Dispatch<one::Tensor>(*parallel_distribution_cast_op_expr,
{consistent_tensor}));
const auto& ret = JUST(OpInterpUtil::Dispatch<one::Tensor>(*parallel_distribution_cast_op_expr,
{consistent_tensor}));
return ret;
}

Maybe<Tensor> LocalToConsistent(const std::shared_ptr<Tensor>& x,
Expand Down

0 comments on commit ad5c97c

Please sign in to comment.