diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 01bfbfc1dfa..4d090c7919d 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -489,6 +489,11 @@ jobs: docker run $extra_docker_args \ oneflow-test:$USER \ bash -c "bash ci/test/try_install.sh && bash ci/test/build_docs.sh" + - name: Query system status + if: ${{ failure() }} + run: | + nvidia-smi + docker ps - name: Remove container if: always() run: | diff --git a/cmake/third_party/glog.cmake b/cmake/third_party/glog.cmake index 36deb257c4e..fe1b6674338 100644 --- a/cmake/third_party/glog.cmake +++ b/cmake/third_party/glog.cmake @@ -18,9 +18,9 @@ else() if(BUILD_SHARED_LIBS) # Must use a shared lib with cpack version set(GLOG_VER 0.3.4) - if(${CMAKE_SHARED_LIBRARY_SUFFIX} STREQUAL ".dylib") + if("${CMAKE_SHARED_LIBRARY_SUFFIX}" STREQUAL ".dylib") set(GLOG_LIBRARY_NAMES libglog.${GLOG_VER}.dylib) - elseif(${CMAKE_SHARED_LIBRARY_SUFFIX} STREQUAL ".so") + elseif("${CMAKE_SHARED_LIBRARY_SUFFIX}" STREQUAL ".so") set(GLOG_LIBRARY_NAMES libglog.so.${GLOG_VER}) else() message(FATAL_ERROR "${CMAKE_SHARED_LIBRARY_SUFFIX} not support for glog") diff --git a/cmake/third_party/protobuf.cmake b/cmake/third_party/protobuf.cmake index 10e4b4cf259..68b6be9f945 100644 --- a/cmake/third_party/protobuf.cmake +++ b/cmake/third_party/protobuf.cmake @@ -19,9 +19,9 @@ else() set(PROTOBUF_BUILD_LIBRARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf) if(BUILD_SHARED_LIBS) set(PB_VER 3.9.2.0) - if(${CMAKE_SHARED_LIBRARY_SUFFIX} STREQUAL ".dylib") + if("${CMAKE_SHARED_LIBRARY_SUFFIX}" STREQUAL ".dylib") set(PROTOBUF_LIBRARY_NAMES libprotobuf.${PB_VER}.dylib) - elseif(${CMAKE_SHARED_LIBRARY_SUFFIX} STREQUAL ".so") + elseif("${CMAKE_SHARED_LIBRARY_SUFFIX}" STREQUAL ".so") set(PROTOBUF_LIBRARY_NAMES libprotobuf.so.${PB_VER}) else() message(FATAL_ERROR "${CMAKE_SHARED_LIBRARY_SUFFIX} not support for protobuf") diff --git a/cmake/third_party/zlib.cmake b/cmake/third_party/zlib.cmake index 77de2f0f82f..191dbd7b562 100644 --- a/cmake/third_party/zlib.cmake +++ b/cmake/third_party/zlib.cmake @@ -10,12 +10,20 @@ use_mirror(VARIABLE ZLIB_URL URL ${ZLIB_URL}) if(WIN32) set(ZLIB_BUILD_LIBRARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/zlib/install/lib) set(ZLIB_LIBRARY_NAMES zlibstaticd.lib) -elseif(APPLE AND ("${CMAKE_GENERATOR}" STREQUAL "Xcode")) - set(ZLIB_BUILD_LIBRARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/zlib/install/lib) - set(ZLIB_LIBRARY_NAMES libz.a) else() set(ZLIB_BUILD_LIBRARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/zlib/install/lib) - set(ZLIB_LIBRARY_NAMES libz.a) + if(BUILD_SHARED_LIBS) + set(Z_VER 1.2.8) + if("${CMAKE_SHARED_LIBRARY_SUFFIX}" STREQUAL ".dylib") + set(ZLIB_LIBRARY_NAMES libz.${Z_VER}.dylib) + elseif("${CMAKE_SHARED_LIBRARY_SUFFIX}" STREQUAL ".so") + set(ZLIB_LIBRARY_NAMES libz.so.${Z_VER}) + else() + message(FATAL_ERROR "${CMAKE_SHARED_LIBRARY_SUFFIX} not support for zlib") + endif() + else() + set(ZLIB_LIBRARY_NAMES libz.a) + endif() endif() foreach(LIBRARY_NAME ${ZLIB_LIBRARY_NAMES}) diff --git a/oneflow/core/autograd/gradient_funcs/normalization.cpp b/oneflow/core/autograd/gradient_funcs/normalization.cpp index 46bec025154..e78c8cc94d3 100644 --- a/oneflow/core/autograd/gradient_funcs/normalization.cpp +++ b/oneflow/core/autograd/gradient_funcs/normalization.cpp @@ -26,6 +26,8 @@ namespace oneflow { namespace one { struct NormalizationGradInterpState : public OpExprInterpState { + int32_t axis; + float epsilon; bool is_training; }; @@ -42,16 +44,15 @@ class NormalizationGrad : public OpExprGradFunctionop_name(); 
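
Note: the `normalization.cpp` hunks around this point stop reading `axis` and `epsilon` once in `Init()` and instead record them per call: `Capture()` copies them into the interp state, and `Apply()` forwards them to the pre-built backward ops through a `MutableAttrMap`. The sketch below is a condensed, self-contained model of that capture-then-apply flow; `AttrMap`, `NormalizationGradState`, and `BuildNormGradAttrs` are simplified stand-ins, not the OneFlow types.

```cpp
#include <iostream>
#include <map>
#include <string>

// Simplified stand-in for OneFlow's AttrMap / MutableAttrMap.
using AttrMap = std::map<std::string, double>;

// Per-call state: the attributes are no longer baked into the grad ops at Init().
struct NormalizationGradState {
  int axis = -1;
  double epsilon = 0.0;
  bool is_training = false;
};

// Capture(): copy the forward op's runtime attributes into the state.
void Capture(NormalizationGradState* ctx, const AttrMap& attrs) {
  ctx->axis = static_cast<int>(attrs.at("axis"));
  ctx->epsilon = attrs.at("epsilon");
  ctx->is_training = attrs.at("training") != 0.0;
}

// Apply(): rebuild an attribute map for this call and hand it to the
// pre-built backward ops, instead of relying on values read at Init().
AttrMap BuildNormGradAttrs(const NormalizationGradState& ctx) {
  return AttrMap{{"axis", static_cast<double>(ctx.axis)}, {"epsilon", ctx.epsilon}};
}

int main() {
  const AttrMap forward_attrs{{"axis", 1.0}, {"epsilon", 1e-5}, {"training", 1.0}};
  NormalizationGradState state;
  Capture(&state, forward_attrs);
  const AttrMap grad_attrs = BuildNormGradAttrs(state);
  std::cout << "axis=" << grad_attrs.at("axis") << " epsilon=" << grad_attrs.at("epsilon") << "\n";
  return 0;
}
```

The practical effect is that the same grad function instance works for forward ops whose `axis`/`epsilon` differ between calls, which ctor-time attribute capture could not do.
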
op_trait_ = std::make_shared(op_name, fw_op_expr->proto()); - const float epsilon = JUST(op_trait_->GetAttr("epsilon")); - axis_ = JUST(op_trait_->GetAttr("axis")); // v1 = variance + eps - add_eps_op_ = JUST(op_expr_helper::ScalarAddOp(epsilon, GradientOpName(op_name + "_add_eps"))); + add_eps_op_ = + JUST(op_expr_helper::ScalarAddOp(/*epsilon=*/0, GradientOpName(op_name + "_add_eps"))); // v2 = rsqrt(v1) rsqrt_op_ = JUST(op_expr_helper::RsqrtOp(GradientOpName(op_name + "_rsqrt"))); // Normalization grad. normalization_grad_op_ = JUST(op_expr_helper::NormalizationGradOp( - axis_, epsilon, GradientOpName(op_name + "_norm_grad"))); + /*axis=*/-1, /*epsilon=*/0, GradientOpName(op_name + "_norm_grad"))); reshape_gamma_op_ = JUST(op_expr_helper::ReshapeOp(Shape{-1}, GradientOpName(op_name + "_reshape_gamma"))); @@ -68,6 +69,8 @@ class NormalizationGrad : public OpExprGradFunction Capture(NormalizationGradInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { + ctx->axis = JUST(op_trait_->GetAttr("axis", attrs)); + ctx->epsilon = JUST(op_trait_->GetAttr("epsilon", attrs)); ctx->is_training = JUST(op_trait_->GetAttr("training", attrs)); ctx->SaveTensorForBackward(inputs.at(0)); // x ctx->SaveTensorForBackward(inputs.at(3)); // gamma @@ -94,12 +97,19 @@ class NormalizationGrad : public OpExprGradFunctionSavedTensors().at(2); // moving_mean const auto& moving_variance = ctx->SavedTensors().at(3); // moving_variance - const auto& add_eps = JUST(OpInterpUtil::Dispatch(*add_eps_op_, {moving_variance})); + MutableAttrMap epsilon_attr; + JUST(epsilon_attr.SetAttr("epsilon", ctx->epsilon)); + const auto& add_eps = + JUST(OpInterpUtil::Dispatch(*add_eps_op_, {moving_variance}, epsilon_attr)); mean = moving_mean; inv_variance = JUST(OpInterpUtil::Dispatch(*rsqrt_op_, {add_eps})); } + + MutableAttrMap norm_grad_attr; + JUST(norm_grad_attr.SetAttr("axis", ctx->axis)); + JUST(norm_grad_attr.SetAttr("epsilon", ctx->epsilon)); const auto& results = JUST(OpInterpUtil::Dispatch( - *normalization_grad_op_, {x, y_grad, gamma, mean, inv_variance})); + *normalization_grad_op_, {x, y_grad, gamma, mean, inv_variance}, norm_grad_attr)); CHECK_EQ_OR_RETURN(results->size(), 3); // The normalization op has 5 inputs which are x, moving_mean, moving_variance, gamma and beta. in_grads->resize(5); @@ -112,10 +122,10 @@ class NormalizationGrad : public OpExprGradFunctionshape()->NumAxes(); ++i) { - if (i != axis_) { + if (i != ctx->axis) { dim_vec.push_back(1); } else { - dim_vec.push_back(x->shape()->At(axis_)); + dim_vec.push_back(x->shape()->At(ctx->axis)); } } MutableAttrMap shape_attr; @@ -142,7 +152,6 @@ class NormalizationGrad : public OpExprGradFunction op_trait_; - int32_t axis_; std::shared_ptr add_eps_op_; std::shared_ptr rsqrt_op_; std::shared_ptr normalization_grad_op_; diff --git a/oneflow/core/autograd/gradient_funcs/where.cpp b/oneflow/core/autograd/gradient_funcs/where.cpp new file mode 100644 index 00000000000..3c54b7eaa51 --- /dev/null +++ b/oneflow/core/autograd/gradient_funcs/where.cpp @@ -0,0 +1,91 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
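
Note: the new `gradient_funcs/where.cpp` below registers a backward function for `out = where(cond, x, y)`. The rule is element-wise: the incoming gradient is routed to `x` where the condition holds and to `y` elsewhere, which the grad function expresses with a `zeros_like` tensor and two forward `where` dispatches. A plain reference implementation of that rule follows (ignoring broadcasting, which the Python wrapper handles by pre-broadcasting the operands; `WhereBackwardReference` is an illustrative helper, not OneFlow API).

```cpp
#include <cstdint>
#include <vector>

// Reference semantics for the backward pass of out = where(cond, x, y):
//   dx[i] = cond[i] ? dout[i] : 0
//   dy[i] = cond[i] ? 0       : dout[i]
// The grad function computes the same thing as
//   dx = where(cond, dout, zeros) and dy = where(cond, zeros, dout).
void WhereBackwardReference(const std::vector<int8_t>& cond, const std::vector<float>& dout,
                            std::vector<float>* dx, std::vector<float>* dy) {
  dx->assign(dout.size(), 0.0f);
  dy->assign(dout.size(), 0.0f);
  for (size_t i = 0; i < dout.size(); ++i) {
    if (cond[i] != 0) {
      (*dx)[i] = dout[i];
    } else {
      (*dy)[i] = dout[i];
    }
  }
}
```

This matches the expectations in the new `_test_where_backward` test further down, where after `sum().backward()` the gradient of `x` equals `condition == 1` and the gradient of `y` equals `condition == 0`.
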
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/core/framework/op_expr_grad_function.h" +#include "oneflow/core/framework/op_builder.h" +#include "oneflow/core/framework/op_expr.h" +#include "oneflow/core/framework/op_expr_helper.h" +#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" + +namespace oneflow { +namespace one { + +struct WhereInterpState : public OpExprInterpState { + bool requires_grad_x; + bool requires_grad_y; +}; + +class Where : public OpExprGradFunction { + public: + Maybe Init(const OpExpr& op) override; + Maybe Capture(WhereInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, + const AttrMap& attrs) const override; + Maybe Apply(const WhereInterpState* ctx, const TensorTuple& out_grads, + TensorTuple* in_grads) const override; + + private: + AttrMap base_attrs_; + std::shared_ptr zero_like_op_; + std::shared_ptr where_op_x_; + std::shared_ptr where_op_y_; +}; + +Maybe Where::Init(const OpExpr& op) { + const UserOpExpr* fw_op_expr = dynamic_cast(&op); + CHECK_NOTNULL_OR_RETURN(fw_op_expr); + base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); + const std::string& op_name = fw_op_expr->op_name(); + zero_like_op_ = JUST(op_expr_helper::ZeroLikeOp("zeros_like_" + GradientOpName(op_name))); + where_op_x_ = JUST(op_expr_helper::WhereOp("where_x_" + GradientOpName(op_name))); + where_op_y_ = JUST(op_expr_helper::WhereOp("where_y_" + GradientOpName(op_name))); + return Maybe::Ok(); +} + +Maybe Where::Capture(WhereInterpState* ctx, const TensorTuple& inputs, + const TensorTuple& outputs, const AttrMap& attrs) const { + ctx->requires_grad_x = inputs.at(1)->requires_grad(); + ctx->requires_grad_y = inputs.at(2)->requires_grad(); + if ((!ctx->requires_grad_x) && (!ctx->requires_grad_y)) { return Maybe::Ok(); } + + ComposedAttrMap composed_attrs(attrs, base_attrs_); + ctx->SaveTensorForBackward(inputs.at(0)); // condition + ctx->SaveTensorForBackward(inputs.at(1)); // x + return Maybe::Ok(); +} + +Maybe Where::Apply(const WhereInterpState* ctx, const TensorTuple& out_grads, + TensorTuple* in_grads) const { + if ((!ctx->requires_grad_x) && (!ctx->requires_grad_y)) { return Maybe::Ok(); } + CHECK_EQ_OR_RETURN(out_grads.size(), 1); + MutableAttrMap attrs; + const std::shared_ptr& condtion = ctx->SavedTensors().at(0); + const std::shared_ptr& x = ctx->SavedTensors().at(1); + + std::shared_ptr zero_out = + JUST(OpInterpUtil::Dispatch(*zero_like_op_, {x})); + in_grads->resize(3); + if (ctx->requires_grad_x) + in_grads->at(1) = + JUST(OpInterpUtil::Dispatch(*where_op_x_, {condtion, out_grads.at(0), zero_out})); + if (ctx->requires_grad_y) + in_grads->at(2) = + JUST(OpInterpUtil::Dispatch(*where_op_y_, {condtion, zero_out, out_grads.at(0)})); + return Maybe::Ok(); +} + +REGISTER_OP_EXPR_GRAD_FUNCTION("where", Where); + +} // namespace one +} // namespace oneflow diff --git a/oneflow/core/control/ctrl_client.cpp b/oneflow/core/control/ctrl_client.cpp index 34dfb8e0b5f..a4ef7ee2b73 100644 --- a/oneflow/core/control/ctrl_client.cpp +++ b/oneflow/core/control/ctrl_client.cpp @@ -23,13 +23,7 @@ namespace { } // namespace -GrpcCtrlClient::~GrpcCtrlClient() { - { - 
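
Note: the `ctrl_client.cpp` change that starts here factors the destructor's shutdown logic into `StopHeartbeat()`, which is safe to call more than once: the stop flag is flipped under the mutex, and only the caller that actually performed the false-to-true transition joins the heartbeat thread. That is what lets both `GrpcRpcManager::~GrpcRpcManager()` (see the `rpc/lib/grpc.cpp` hunk further down) and `~GrpcCtrlClient()` invoke it. A minimal sketch of the idempotent stop-and-join pattern, with a plain `std::thread` and the placeholder class `HeartbeatOwner` standing in for the real client:

```cpp
#include <chrono>
#include <mutex>
#include <thread>

class HeartbeatOwner {
 public:
  HeartbeatOwner() : thread_([this] { Loop(); }) {}
  ~HeartbeatOwner() { Stop(); }  // a second call after an explicit Stop() is a no-op

  void Stop() {
    bool already_stopped = false;
    {
      std::unique_lock<std::mutex> lock(mutex_);
      already_stopped = stop_;
      stop_ = true;
    }
    // Only the caller that performed the false -> true transition joins,
    // so Stop() may run from both explicit shutdown code and the destructor.
    if (!already_stopped) { thread_.join(); }
  }

 private:
  void Loop() {
    for (;;) {
      {
        std::unique_lock<std::mutex> lock(mutex_);
        if (stop_) { return; }
      }
      std::this_thread::sleep_for(std::chrono::milliseconds(10));
    }
  }

  std::mutex mutex_;
  bool stop_ = false;
  std::thread thread_;
};

int main() {
  HeartbeatOwner owner;
  owner.Stop();  // explicit shutdown path
  return 0;      // destructor calls Stop() again; already_stopped skips the join
}
```
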
std::unique_lock lck(need_heartbeat_thread_stop_mtx_); - need_heartbeat_thread_stop_ = true; - } - heartbeat_thread_.join(); -} +GrpcCtrlClient::~GrpcCtrlClient() { StopHeartbeat(); } GrpcCtrlClient::GrpcCtrlClient(const ProcessCtx& process_ctx) : process_ctx_(process_ctx) { rpc_client_.ReserveStubsOfSize(process_ctx.ctrl_addr_size()); @@ -118,4 +112,14 @@ int32_t GrpcCtrlClient::IncreaseCount(const std::string& k, int32_t v) { void GrpcCtrlClient::EraseCount(const std::string& k) { rpc_client_.EraseCount(k); } +void GrpcCtrlClient::StopHeartbeat() { + bool already_stopped = false; + { + std::unique_lock lck(need_heartbeat_thread_stop_mtx_); + already_stopped = need_heartbeat_thread_stop_; + need_heartbeat_thread_stop_ = true; + } + if (!already_stopped) { heartbeat_thread_.join(); } +} + } // namespace oneflow diff --git a/oneflow/core/eager/opkernel_instruction_type.cpp b/oneflow/core/eager/opkernel_instruction_type.cpp index 5a74ec1e2bc..ad68caceb46 100644 --- a/oneflow/core/eager/opkernel_instruction_type.cpp +++ b/oneflow/core/eager/opkernel_instruction_type.cpp @@ -443,9 +443,9 @@ Maybe GetSharedOpKernel(vm::Instruction* instruction, DeviceType device_type struct LocalCallOpKernelUtil final { static inline Maybe Infer(vm::Instruction* instruction) { auto* operand = JUST(GetLocalCallOpKernelPhyInstrOperand(instruction)); + operand->mut_opkernel()->composed_attrs_for_scheduler_thread()->ResetPrior(operand->attrs()); operand->set_user_opkernel( JUST(operand->mut_opkernel()->ChooseOpKernel(operand->inputs(), operand->outputs()))); - operand->mut_opkernel()->ResetDynamicOpAttrs(operand->attrs()); JUST(CheckOutputBlobObjectsMemCase(operand, instruction->stream())); JUST(InitOutputBlobs(operand)); JUST(InferTempStorageBlobDesc(operand)); @@ -522,7 +522,8 @@ struct LocalCallOpKernelUtil final { const auto& InferTmpSizeFn = operand->opkernel().GetInferTmpSizeFn(operand->user_opkernel()); auto* temp_blob_desc = operand->mut_opkernel()->mut_temp_blob_object()->mut_blob_desc(); CHECK_OR_RETURN(temp_blob_desc->data_type() == DataType::kChar); - one::LocalUserOpInferContext* op_infer_ctx = operand->opkernel().op_infer_ctx_for_thread_a(); + one::LocalUserOpInferContext* op_infer_ctx = + operand->opkernel().op_infer_ctx_for_scheduler_thread(); op_infer_ctx->Update(operand->inputs(), operand->outputs()); size_t temp_size = InferTmpSizeFn(op_infer_ctx); temp_blob_desc->mut_shape() = Shape({static_cast(temp_size)}); diff --git a/oneflow/core/framework/op_expr_helper.cpp b/oneflow/core/framework/op_expr_helper.cpp index 0fded2d3edf..407a5faedff 100644 --- a/oneflow/core/framework/op_expr_helper.cpp +++ b/oneflow/core/framework/op_expr_helper.cpp @@ -603,6 +603,16 @@ Maybe SplitLikeOp(const int n, const int64_t axis, const std::s .Build(); } +Maybe WhereOp() { return WhereOp(UniqueOpName("where")); } +Maybe WhereOp(const std::string& name) { + return one::OpBuilder("where", name) + .Input("condition") + .Input("x") + .Input("y") + .Output("out") + .Build(); +} + Maybe ExpandGradOp(const std::vector& out_shape, const std::vector& stride) { return ExpandGradOp(out_shape, stride, UniqueOpName("expand_grad")); diff --git a/oneflow/core/framework/op_expr_helper.h b/oneflow/core/framework/op_expr_helper.h index 14090caecb1..c9843ac842f 100644 --- a/oneflow/core/framework/op_expr_helper.h +++ b/oneflow/core/framework/op_expr_helper.h @@ -195,6 +195,9 @@ Maybe TransposeOp(const std::vector& perm, const std:: Maybe SplitLikeOp(const int n, const int64_t axis); Maybe SplitLikeOp(const int n, const int64_t axis, 
const std::string& name); +Maybe WhereOp(); +Maybe WhereOp(const std::string& name); + Maybe ExpandGradOp(const std::vector& out_shape, const std::vector& stride); Maybe ExpandGradOp(const std::vector& out_shape, diff --git a/oneflow/core/framework/op_interpreter/eager_mirrored_op_interpreter.cpp b/oneflow/core/framework/op_interpreter/eager_mirrored_op_interpreter.cpp index f4cf3b45c7d..209edc4be87 100644 --- a/oneflow/core/framework/op_interpreter/eager_mirrored_op_interpreter.cpp +++ b/oneflow/core/framework/op_interpreter/eager_mirrored_op_interpreter.cpp @@ -93,11 +93,11 @@ Maybe NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& in output_eager_blob_objects->at(index)->set_is_shape_synced(false); } - kernel->ResetDynamicOpAttrs(attrs); + kernel->composed_attrs_for_main_thread()->ResetPrior(attrs); JUST(kernel->InferDataType(input_eager_blob_objects, output_eager_blob_objects, - kernel->op_infer_ctx_for_thread_b())); + kernel->op_infer_ctx_for_main_thread())); JUST(kernel->InferTensorDesc(input_eager_blob_objects, output_eager_blob_objects, - kernel->op_infer_ctx_for_thread_b())); + kernel->op_infer_ctx_for_main_thread())); const auto& instr_type_name = JUST(op_device->local_call_instruction_name()); JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe { diff --git a/oneflow/core/rpc/include/grpc.h b/oneflow/core/rpc/include/grpc.h index 61f287e8abc..2bc12582da1 100644 --- a/oneflow/core/rpc/include/grpc.h +++ b/oneflow/core/rpc/include/grpc.h @@ -25,8 +25,8 @@ namespace oneflow { class GrpcCtrlClient final : public CtrlClient { public: OF_DISALLOW_COPY_AND_MOVE(GrpcCtrlClient); - GrpcCtrlClient(const ProcessCtx& process_ctx); - ~GrpcCtrlClient(); + explicit GrpcCtrlClient(const ProcessCtx& process_ctx); + ~GrpcCtrlClient() override; void Barrier(const std::string& barrier_name) override; void Barrier(const std::string& barrier_name, int32_t barrier_num) override; @@ -51,6 +51,7 @@ class GrpcCtrlClient final : public CtrlClient { void Clear() override; int32_t IncreaseCount(const std::string& k, int32_t v) override; void EraseCount(const std::string& k) override; + void StopHeartbeat(); private: const ProcessCtx& process_ctx() const { return process_ctx_; } diff --git a/oneflow/core/rpc/lib/grpc.cpp b/oneflow/core/rpc/lib/grpc.cpp index 6f8681a9a04..24478b0d37e 100644 --- a/oneflow/core/rpc/lib/grpc.cpp +++ b/oneflow/core/rpc/lib/grpc.cpp @@ -15,11 +15,10 @@ limitations under the License. 
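
Note: in the `rpc/include/grpc.h` hunk above, marking the single-argument constructor `explicit` prevents a `ProcessCtx` from implicitly converting into a `GrpcCtrlClient`, and `override` on the destructor makes the compiler verify that the base destructor is virtual. A minimal illustration of both effects; `Base`, `Config`, and `Client` are placeholders, not the OneFlow types.

```cpp
struct Base {
  virtual ~Base() = default;
};

struct Config {};

struct Client : Base {
  explicit Client(const Config& cfg) : cfg_(cfg) {}  // blocks implicit Config -> Client conversion
  ~Client() override = default;  // fails to compile if Base's destructor stops being virtual

  Config cfg_;
};

void Consume(const Client&) {}

void Demo(const Config& cfg) {
  // Consume(cfg);        // ill-formed with `explicit`; would compile without it
  Consume(Client(cfg));   // the conversion now has to be spelled out at the call site
}
```
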
*/ #ifdef RPC_BACKEND_GRPC -#include "oneflow/core/control/ctrl_client.h" +#include "oneflow/core/rpc/include/grpc.h" #include "oneflow/core/control/ctrl_bootstrap.h" #include "oneflow/core/control/ctrl_server.h" #include "oneflow/core/job/env_desc.h" -#include "oneflow/core/rpc/include/grpc.h" namespace oneflow { @@ -57,6 +56,10 @@ Maybe GrpcRpcManager::CreateClient() { } GrpcRpcManager::~GrpcRpcManager() { + auto* grpc_client = dynamic_cast(Global::Get()); + CHECK_NOTNULL(grpc_client); + grpc_client->StopHeartbeat(); + OF_ENV_BARRIER(); Global::Delete(); CHECK_NOTNULL(Global::Get()); Global::Delete(); @@ -64,4 +67,4 @@ GrpcRpcManager::~GrpcRpcManager() { } // namespace oneflow -#endif // RPC_BACKEND_GPRC +#endif // RPC_BACKEND_GRPC diff --git a/oneflow/core/vm/instruction.msg.h b/oneflow/core/vm/instruction.msg.h index a3be1ea4a15..fc51c8a4d64 100644 --- a/oneflow/core/vm/instruction.msg.h +++ b/oneflow/core/vm/instruction.msg.h @@ -115,9 +115,13 @@ OBJECT_MSG_BEGIN(InstructionEdge); set_src_instruction(src_instruction); set_dst_instruction(dst_instruction); } + + // fields + OBJECT_MSG_DEFINE_PTR(Instruction, src_instruction); + OBJECT_MSG_DEFINE_PTR(Instruction, dst_instruction); // links - OBJECT_MSG_DEFINE_SKIPLIST_KEY(10, Instruction*, src_instruction); - OBJECT_MSG_DEFINE_SKIPLIST_KEY(10, Instruction*, dst_instruction); + OBJECT_MSG_DEFINE_LIST_LINK(src_instruction_link); + OBJECT_MSG_DEFINE_LIST_LINK(dst_instruction_link); OBJECT_MSG_END(InstructionEdge); // clang-format on @@ -229,8 +233,8 @@ OBJECT_MSG_BEGIN(Instruction); OBJECT_MSG_DEFINE_LIST_LINK(front_seq_infer_instr_link); OBJECT_MSG_DEFINE_LIST_LINK(front_seq_compute_instr_link); OBJECT_MSG_DEFINE_LIST_HEAD(CallbackMsg, callback_link, callback_list); - OBJECT_MSG_DEFINE_SKIPLIST_HEAD(InstructionEdge, src_instruction, in_edges); - OBJECT_MSG_DEFINE_SKIPLIST_HEAD(InstructionEdge, dst_instruction, out_edges); + OBJECT_MSG_DEFINE_LIST_HEAD(InstructionEdge, src_instruction_link, in_edges); + OBJECT_MSG_DEFINE_LIST_HEAD(InstructionEdge, dst_instruction_link, out_edges); OBJECT_MSG_DEFINE_SKIPLIST_HEAD(RwMutexedObjectAccess, mirrored_object_id, mirrored_object_id2access); OBJECT_MSG_DEFINE_LIST_HEAD(RwMutexedObjectAccess, instruction_access_link, access_list); OBJECT_MSG_END(Instruction); diff --git a/oneflow/core/vm/virtual_machine.cpp b/oneflow/core/vm/virtual_machine.cpp index ea894931e0e..ecbc4110cbe 100644 --- a/oneflow/core/vm/virtual_machine.cpp +++ b/oneflow/core/vm/virtual_machine.cpp @@ -380,9 +380,8 @@ void VirtualMachine::ConnectInstruction(Instruction* src_instruction, CHECK_NE(src_instruction, dst_instruction); auto edge = ObjectMsgPtr::NewFrom(mut_vm_thread_only_allocator(), src_instruction, dst_instruction); - bool src_inserted = src_instruction->mut_out_edges()->Insert(edge.Mutable()).second; - bool dst_inserted = dst_instruction->mut_in_edges()->Insert(edge.Mutable()).second; - CHECK_EQ(src_inserted, dst_inserted); + src_instruction->mut_out_edges()->PushBack(edge.Mutable()); + dst_instruction->mut_in_edges()->PushBack(edge.Mutable()); } void VirtualMachine::ConsumeMirroredObjects(Id2LogicalObject* id2logical_object, @@ -537,8 +536,8 @@ void VirtualMachine::TryMoveWaitingToReady(Instruction* instruction, ReadyList* const IsEdgeReadyT& IsEdgeReady) { auto* wait_instruction_list = mut_waiting_instruction_list(); auto* out_edges = instruction->mut_out_edges(); - OBJECT_MSG_SKIPLIST_FOR_EACH_PTR(out_edges, out_edge) { - Instruction* out_instruction = out_edge->dst_instruction(); + 
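
Note: the `instruction.msg.h` and `ConnectInstruction` hunks above replace the skiplists that indexed `InstructionEdge` by source/destination instruction with plain intrusive lists: each edge keeps `src_instruction`/`dst_instruction` as ordinary pointer fields and is linked into the instructions' `in_edges`/`out_edges` by `PushBack`, so connecting two instructions becomes a constant-time append instead of an ordered insert. A simplified, macro-free model of that layout; `std::list` of raw pointers stands in for the intrusive OBJECT_MSG links and `Connect` is illustrative only.

```cpp
#include <list>
#include <memory>
#include <vector>

struct Instruction;

// An edge now stores both endpoints as plain pointer fields and is linked
// into the endpoint lists, instead of being keyed into two skiplists.
struct InstructionEdge {
  Instruction* src_instruction = nullptr;
  Instruction* dst_instruction = nullptr;
};

struct Instruction {
  std::list<InstructionEdge*> in_edges;   // edges whose dst is this instruction
  std::list<InstructionEdge*> out_edges;  // edges whose src is this instruction
};

// Mirrors ConnectInstruction after the change: two O(1) appends. The old
// skiplist version also deduplicated edges by key; the list version appends
// unconditionally, so callers are expected not to connect the same pair twice.
InstructionEdge* Connect(Instruction* src, Instruction* dst,
                         std::vector<std::unique_ptr<InstructionEdge>>* storage) {
  auto edge = std::make_unique<InstructionEdge>();
  edge->src_instruction = src;
  edge->dst_instruction = dst;
  InstructionEdge* raw = edge.get();
  src->out_edges.push_back(raw);
  dst->in_edges.push_back(raw);
  storage->push_back(std::move(edge));
  return raw;
}
```
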
OBJECT_MSG_LIST_FOR_EACH_PTR(out_edges, out_edge) { + Instruction* out_instruction = out_edge->mut_dst_instruction(); if (!IsEdgeReady(out_instruction)) { continue; } out_edges->Erase(out_edge); out_instruction->mut_in_edges()->Erase(out_edge); @@ -582,6 +581,11 @@ void VirtualMachine::Receive(InstructionMsgList* compute_instr_msg_list) { } compute_instr_msg_list->MoveToDstBack(compute_instr_msg, &new_instr_msg_list); } + static const int64_t kHighWaterMark = 500; + static const int64_t kLowWaterMark = 200; + if (*mut_flying_instruction_cnt() > kHighWaterMark) { + while (*mut_flying_instruction_cnt() > kLowWaterMark) {} + } mut_pending_msg_list()->MoveFrom(&new_instr_msg_list); } @@ -656,6 +660,9 @@ void VirtualMachine::Schedule() { new_instruction_list.MoveTo(waiting_instruction_list); } DispatchAndPrescheduleInstructions(ready_instruction_list); + *mut_flying_instruction_cnt() = mut_waiting_instruction_list()->size() + + mut_ready_instruction_list()->size() + + mutable_vm_stat_running_instruction_list()->size(); } bool VirtualMachine::Empty() const { diff --git a/oneflow/core/vm/virtual_machine.msg.h b/oneflow/core/vm/virtual_machine.msg.h index 4673697bf6b..3d91bfa7de8 100644 --- a/oneflow/core/vm/virtual_machine.msg.h +++ b/oneflow/core/vm/virtual_machine.msg.h @@ -55,6 +55,7 @@ OBJECT_MSG_BEGIN(VirtualMachine); // fields OBJECT_MSG_DEFINE_OPTIONAL(VmResourceDesc, vm_resource_desc); OBJECT_MSG_DEFINE_STRUCT(Range, machine_id_range); + OBJECT_MSG_DEFINE_STRUCT(std::atomic, flying_instruction_cnt); OBJECT_MSG_DEFINE_PTR(ObjectMsgAllocator, vm_thread_only_allocator); // heads diff --git a/oneflow/init.py b/oneflow/init.py index 52d95975e13..92090df08f8 100644 --- a/oneflow/init.py +++ b/oneflow/init.py @@ -68,9 +68,28 @@ env_util.init_default_physical_env() del env_util + +# capture oneflow methods so that they can be still accessed after `del oneflow` +def _SyncOnMasterFn(get_rank, sync): + def SyncOnMaster(): + if get_rank() == 0: + sync() + + return SyncOnMaster + + atexit.register(oneflow._oneflow_internal.SetShuttingDown) atexit.register(oneflow._oneflow_internal.DestroyEnv) atexit.register(oneflow.python.framework.session_context.TryCloseDefaultSession) +# Global::Get(), used by vm in background thread, +# will be set to nullptr by TryCloseDefaultSession, +# so sync vm in advance to avoid data race +atexit.register( + _SyncOnMasterFn( + oneflow.python.framework.distribute.get_rank, + oneflow._oneflow_internal.eager.single_client.Sync, + ) +) del atexit import sys diff --git a/oneflow/python/nn/modules/where.py b/oneflow/python/nn/modules/where.py index 906aef34037..393250edff0 100644 --- a/oneflow/python/nn/modules/where.py +++ b/oneflow/python/nn/modules/where.py @@ -34,9 +34,22 @@ def __init__(self) -> None: def forward(self, condition, x, y): assert condition.dtype == flow.int32 or condition.dtype == flow.int8 if isinstance(x, int) or isinstance(x, float): - x = flow.Tensor([float(x)], dtype=flow.float32) + x = flow.Tensor( + [float(x)], + dtype=flow.float32, + device=flow.device(condition.device.type), + ) if isinstance(y, int) or isinstance(y, float): - y = flow.Tensor([float(y)], dtype=flow.float32) + y = flow.Tensor( + [float(y)], + dtype=flow.float32, + device=flow.device(condition.device.type), + ) + + assert ( + condition.device.type == x.device.type + and condition.device.type == y.device.type + ) broadcast_cond = condition broadcast_x = x broadcast_y = y @@ -59,6 +72,8 @@ def forward(self, condition, x, y): broadcast_like_tensor = flow.experimental.zeros( 
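
Note: the `virtual_machine.cpp` hunks above add simple hysteresis-style backpressure to `Receive()`: the in-flight instruction count is recomputed at the end of `Schedule()` as waiting + ready + running, and once it exceeds 500 the producing thread spins until the scheduler drains it below 200 before enqueuing more work. A stripped-down model of that high/low watermark loop, assuming the counter is decremented by another thread; `ThrottleProducer` and `flying_cnt` are illustrative names.

```cpp
#include <atomic>
#include <cstdint>

constexpr int64_t kHighWaterMark = 500;
constexpr int64_t kLowWaterMark = 200;

// Called by the producer before publishing new instructions. If the backlog
// is above the high watermark, busy-wait until the consumer (the scheduler
// thread that keeps `flying_cnt` up to date) drains it below the low watermark.
void ThrottleProducer(const std::atomic<int64_t>& flying_cnt) {
  if (flying_cnt.load(std::memory_order_relaxed) > kHighWaterMark) {
    while (flying_cnt.load(std::memory_order_relaxed) > kLowWaterMark) {
      // spin; the patched Receive() busy-waits in the same way
    }
  }
}
```

The gap between the two thresholds avoids oscillating around a single limit; the cost is that the producer burns a core while waiting, which matches the empty spin loop in the diff.
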
tuple(broadcast_like_shape), dtype=flow.float32 ) + broadcast_like_tensor = broadcast_like_tensor.to(x.device.type) + broadcast_like_tensor.requires_grad = x.requires_grad or y.requires_grad if len(broadcast_condition_axes) != 0: condition = flow.experimental.cast(condition, flow.float32) diff --git a/oneflow/python/test/modules/test_where.py b/oneflow/python/test/modules/test_where.py index 00a3fa120eb..3b77dac6aac 100644 --- a/oneflow/python/test/modules/test_where.py +++ b/oneflow/python/test/modules/test_where.py @@ -14,9 +14,162 @@ limitations under the License. """ import unittest +from collections import OrderedDict import numpy as np + import oneflow.experimental as flow +from test_util import GenArgList + + +def _test_where(test_case, device): + x = flow.Tensor( + np.array([[-0.4620, 0.3139], [0.3898, -0.7197], [0.0478, -0.1657]]), + dtype=flow.float32, + device=flow.device(device), + ) + y = flow.Tensor( + np.ones(shape=(3, 2)), dtype=flow.float32, device=flow.device(device) + ) + condition = flow.Tensor( + np.array([[0, 1], [1, 0], [1, 0]]), dtype=flow.int32, device=flow.device(device) + ) + of_out = flow.where(condition, x, y) + np_out = np.array([[1.0000, 0.3139], [0.3898, 1.0000], [0.0478, 1.0000]]) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5)) + + +def _test_where_broadcast(test_case, device): + x = flow.Tensor( + np.array([[[-0.4620, 0.3139], [0.3898, -0.7197], [0.0478, -0.1657]]]), + dtype=flow.float32, + device=flow.device(device), + ) + y = flow.Tensor( + np.ones(shape=(3, 3, 2)), dtype=flow.float32, device=flow.device(device) + ) + condition = flow.Tensor( + np.array([[[0, 1], [1, 0], [1, 0]]]), + dtype=flow.int32, + device=flow.device(device), + ) + of_out = flow.where(condition, x, y) + np_out = np.array( + [ + [[1.0000, 0.3139], [0.3898, 1.0000], [0.0478, 1.0000]], + [[1.0000, 0.3139], [0.3898, 1.0000], [0.0478, 1.0000]], + [[1.0000, 0.3139], [0.3898, 1.0000], [0.0478, 1.0000]], + ] + ) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5)) + + +def _test_where_scalar(test_case, device): + x = 0.5 + y = 2.0 + condition = flow.Tensor(np.array([1]), dtype=flow.int32) + of_out = flow.where(condition, x, y) + np_out = np.array([0.5]) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5)) + + +def _test_where_dim4(test_case, device): + x = flow.Tensor( + np.array([[[[-0.4620, 0.3139], [0.3898, -0.7197], [0.0478, -0.1657]]]]), + dtype=flow.float32, + device=flow.device(device), + ) + y = flow.Tensor( + np.ones(shape=(1, 1, 3, 2)), dtype=flow.float32, device=flow.device(device) + ) + condition = flow.Tensor( + np.array([[[[0, 1], [1, 0], [1, 0]]]]), + dtype=flow.int32, + device=flow.device(device), + ) + of_out = flow.where(condition, x, y) + np_out = np.array([[[[1.0000, 0.3139], [0.3898, 1.0000], [0.0478, 1.0000]]]]) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out, 1e-5, 1e-5)) + + +def _test_where_backward(test_case, device): + x = flow.Tensor( + np.array([[-0.4620, 0.3139], [0.3898, -0.7197], [0.0478, -0.1657]]), + dtype=flow.float32, + device=flow.device(device), + requires_grad=True, + ) + y = flow.Tensor( + np.ones(shape=(3, 2)), + dtype=flow.float32, + device=flow.device(device), + requires_grad=True, + ) + condition = flow.Tensor( + np.array([[0, 1], [1, 0], [1, 0]]), dtype=flow.int32, device=flow.device(device) + ) + of_out = flow.where(condition, x, y) + of_out = of_out.sum() + of_out.backward() + test_case.assertTrue( + np.allclose(x.grad.numpy(), condition.numpy() == 1, 1e-5, 1e-5) + ) + 
test_case.assertTrue( + np.allclose(y.grad.numpy(), condition.numpy() == 0, 1e-5, 1e-5) + ) + + +def _test_where_broadcast_backward(test_case, device): + x = flow.Tensor( + np.array([[[-0.4620, 0.3139], [0.3898, -0.7197], [0.0478, -0.1657]]]), + dtype=flow.float32, + device=flow.device(device), + requires_grad=True, + ) + y = flow.Tensor( + np.ones(shape=(3, 3, 2)), + dtype=flow.float32, + device=flow.device(device), + requires_grad=True, + ) + condition = flow.Tensor( + np.array([[[0, 1], [1, 0], [1, 0]]]), + dtype=flow.int32, + device=flow.device(device), + ) + of_out = flow.where(condition, x, y) + of_out = of_out.sum() + of_out.backward() + x_grad = [[[0.0, 3.0], [3.0, 0.0], [3.0, 0.0]]] + test_case.assertTrue(np.allclose(x.grad.numpy(), x_grad, 1e-5, 1e-5)) + y_grad = [ + [[1.0, 0.0], [0.0, 1.0], [0.0, 1.0]], + [[1.0, 0.0], [0.0, 1.0], [0.0, 1.0]], + [[1.0, 0.0], [0.0, 1.0], [0.0, 1.0]], + ] + test_case.assertTrue(np.allclose(y.grad.numpy(), y_grad, 1e-5, 1e-5)) + + +def _test_where_broadcast_x_backward(test_case, device): + x = flow.Tensor( + np.array([[[-0.4620, 0.3139], [0.3898, -0.7197], [0.0478, -0.1657]]]), + dtype=flow.float32, + device=flow.device(device), + requires_grad=True, + ) + y = flow.Tensor( + np.ones(shape=(3, 3, 2)), dtype=flow.float32, device=flow.device(device), + ) + condition = flow.Tensor( + np.array([[[0, 1], [1, 0], [1, 0]]]), + dtype=flow.int32, + device=flow.device(device), + ) + of_out = flow.where(condition, x, y) + of_out = of_out.sum() + of_out.backward() + x_grad = [[[0.0, 3.0], [3.0, 0.0], [3.0, 0.0]]] + test_case.assertTrue(np.allclose(x.grad.numpy(), x_grad, 1e-5, 1e-5)) @unittest.skipIf( @@ -25,64 +178,19 @@ ) class TestWhere(flow.unittest.TestCase): def test_where(test_case): - x = flow.Tensor( - np.array([[-0.4620, 0.3139], [0.3898, -0.7197], [0.0478, -0.1657]]), - dtype=flow.float32, - ) - y = flow.Tensor(np.ones(shape=(3, 2)), dtype=flow.float32) - condition = flow.Tensor(np.array([[0, 1], [1, 0], [1, 0]]), dtype=flow.int32) - of_out = flow.where(condition, x, y) - np_out = np.array([[1.0000, 0.3139], [0.3898, 1.0000], [0.0478, 1.0000]]) - test_case.assertTrue(np.allclose(of_out.numpy(), np_out)) - - def test_tensor_where(test_case): - x = flow.Tensor( - np.array([[-0.4620, 0.3139], [0.3898, -0.7197], [0.0478, -0.1657]]), - dtype=flow.float32, - ) - y = flow.Tensor(np.ones(shape=(3, 2)), dtype=flow.float32) - condition = flow.Tensor(np.array([[0, 1], [1, 0], [1, 0]]), dtype=flow.int32) - of_out = condition.where(x, y) - np_out = np.array([[1.0000, 0.3139], [0.3898, 1.0000], [0.0478, 1.0000]]) - test_case.assertTrue(np.allclose(of_out.numpy(), np_out)) - - def test_where_broadcast(test_case): - x = flow.Tensor( - np.array([[[-0.4620, 0.3139], [0.3898, -0.7197], [0.0478, -0.1657]]]), - dtype=flow.float32, - ) - y = flow.Tensor(np.ones(shape=(3, 3, 2)), dtype=flow.float32) - condition = flow.Tensor(np.array([[[0, 1], [1, 0], [1, 0]]]), dtype=flow.int32) - of_out = flow.where(condition, x, y) - np_out = np.array( - [ - [[1.0000, 0.3139], [0.3898, 1.0000], [0.0478, 1.0000]], - [[1.0000, 0.3139], [0.3898, 1.0000], [0.0478, 1.0000]], - [[1.0000, 0.3139], [0.3898, 1.0000], [0.0478, 1.0000]], - ] - ) - test_case.assertTrue(np.allclose(of_out.numpy(), np_out)) - - def test_where_scalar(test_case): - x = 0.5 - y = 2.0 - condition = flow.Tensor(np.array([1]), dtype=flow.int32) - of_out = flow.where(condition, x, y) - np_out = np.array([0.5]) - test_case.assertTrue(np.allclose(of_out.numpy(), np_out)) - - def test_where_dim4(test_case): - x = flow.Tensor( - 
np.array([[[[-0.4620, 0.3139], [0.3898, -0.7197], [0.0478, -0.1657]]]]), - dtype=flow.float32, - ) - y = flow.Tensor(np.ones(shape=(1, 1, 3, 2)), dtype=flow.float32) - condition = flow.Tensor( - np.array([[[[0, 1], [1, 0], [1, 0]]]]), dtype=flow.int32 - ) - of_out = flow.where(condition, x, y) - np_out = np.array([[[[1.0000, 0.3139], [0.3898, 1.0000], [0.0478, 1.0000]]]]) - test_case.assertTrue(np.allclose(of_out.numpy(), np_out)) + arg_dict = OrderedDict() + arg_dict["test_fun"] = [ + _test_where, + _test_where_broadcast, + _test_where_scalar, + _test_where_dim4, + _test_where_backward, + _test_where_broadcast_backward, + _test_where_broadcast_x_backward, + ] + arg_dict["device"] = ["cpu", "cuda"] + for arg in GenArgList(arg_dict): + arg[0](test_case, *arg[1:]) if __name__ == "__main__": diff --git a/oneflow/python/test/tensor/test_tensor.py b/oneflow/python/test/tensor/test_tensor.py index d6f5fa74441..95f71931b49 100644 --- a/oneflow/python/test/tensor/test_tensor.py +++ b/oneflow/python/test/tensor/test_tensor.py @@ -560,6 +560,21 @@ def test_construct_small_tensor(test_case): test_case.assertEqual(tensor.dtype, flow.float32) test_case.assertTrue(np.allclose(tensor.numpy(), np.array(scalar), 1e-4, 1e-4)) + @unittest.skipIf( + not flow.unittest.env.eager_execution_enabled(), + "numpy doesn't work in lazy mode", + ) + def test_tensor_where(test_case): + x = flow.Tensor( + np.array([[-0.4620, 0.3139], [0.3898, -0.7197], [0.0478, -0.1657]]), + dtype=flow.float32, + ) + y = flow.Tensor(np.ones(shape=(3, 2)), dtype=flow.float32) + condition = flow.Tensor(np.array([[0, 1], [1, 0], [1, 0]]), dtype=flow.int32) + of_out = condition.where(x, y) + np_out = np.array([[1.0000, 0.3139], [0.3898, 1.0000], [0.0478, 1.0000]]) + test_case.assertTrue(np.allclose(of_out.numpy(), np_out)) + if __name__ == "__main__": unittest.main() diff --git a/oneflow/user/kernels/stateful_local_opkernel.cpp b/oneflow/user/kernels/stateful_local_opkernel.cpp index 4523a867781..6466e9c428e 100644 --- a/oneflow/user/kernels/stateful_local_opkernel.cpp +++ b/oneflow/user/kernels/stateful_local_opkernel.cpp @@ -344,7 +344,8 @@ Maybe InitTensorTupleIndexes4Bns(const std::shared_ptr opkernel->op_conf_ = op_conf; opkernel->user_op_conf_.reset(new user_op::UserOpConfWrapper(op_conf)); opkernel->device_ = device; - opkernel->composed_attrs_.reset(new ComposedAttrMap(base_attrs)); + opkernel->composed_attrs_for_scheduler_thread_.reset(new ComposedAttrMap(base_attrs)); + opkernel->composed_attrs_for_main_thread_.reset(new ComposedAttrMap(base_attrs)); opkernel->input_arg_tuple_ = input_arg_tuple; opkernel->output_arg_tuple_ = output_arg_tuple; opkernel->need_check_mem_case_ = true; @@ -355,17 +356,20 @@ Maybe InitTensorTupleIndexes4Bns(const std::shared_ptr const std::string& device_tag = op_conf->device_tag(); const user_op::UserOpConfWrapper* user_op_conf = opkernel->user_op_conf_.get(); - const ComposedAttrMap* composed_attrs = opkernel->composed_attrs_.get(); - opkernel->op_infer_ctx_for_thread_a_.reset( - new LocalUserOpInferContext(user_op_conf, composed_attrs, input_arg_tuple, output_arg_tuple)); - opkernel->op_infer_ctx_for_thread_b_.reset( - new LocalUserOpInferContext(user_op_conf, composed_attrs, input_arg_tuple, output_arg_tuple)); + opkernel->op_infer_ctx_for_scheduler_thread_.reset(new LocalUserOpInferContext( + user_op_conf, opkernel->composed_attrs_for_scheduler_thread_.get(), input_arg_tuple, + output_arg_tuple)); + opkernel->op_infer_ctx_for_main_thread_.reset( + new LocalUserOpInferContext(user_op_conf, 
opkernel->composed_attrs_for_main_thread_.get(), + input_arg_tuple, output_arg_tuple)); opkernel->compute_ctx_.reset(new LocalUserKernelComputeContext( - nullptr, device_tag, user_op_conf, composed_attrs, input_arg_tuple, output_arg_tuple, - opkernel->mut_temp_blob_object())); - opkernel->create_ctx_.reset(new LocalUserKernelCreateContext(user_op_conf, composed_attrs)); - opkernel->reg_ctx_.reset(new LocalUserKernelRegContext(device_tag, user_op_conf, composed_attrs, - input_arg_tuple, output_arg_tuple)); + nullptr, device_tag, user_op_conf, opkernel->composed_attrs_for_scheduler_thread_.get(), + input_arg_tuple, output_arg_tuple, opkernel->mut_temp_blob_object())); + opkernel->create_ctx_.reset(new LocalUserKernelCreateContext( + user_op_conf, opkernel->composed_attrs_for_scheduler_thread_.get())); + opkernel->reg_ctx_.reset(new LocalUserKernelRegContext( + device_tag, user_op_conf, opkernel->composed_attrs_for_scheduler_thread_.get(), + input_arg_tuple, output_arg_tuple)); const auto* op_reg_val = user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(user_op_conf->op_type_name()); CHECK_NOTNULL_OR_RETURN(op_reg_val); @@ -465,8 +469,5 @@ LocalUserKernelComputeContext* StatefulLocalOpKernel::UpdateComputeContext( return compute_ctx_.get(); } -void StatefulLocalOpKernel::ResetDynamicOpAttrs(const AttrMap& attrs) { - composed_attrs_->ResetPrior(attrs); -} } // namespace one } // namespace oneflow diff --git a/oneflow/user/kernels/stateful_local_opkernel.h b/oneflow/user/kernels/stateful_local_opkernel.h index 011288a4ec8..3157e539ec6 100644 --- a/oneflow/user/kernels/stateful_local_opkernel.h +++ b/oneflow/user/kernels/stateful_local_opkernel.h @@ -286,14 +286,20 @@ class StatefulLocalOpKernel final { const EagerBlobObjectListPtr& outputs, LocalUserOpInferContext* op_infer_ctx); - void ResetDynamicOpAttrs(const AttrMap& attrs); + ComposedAttrMap* composed_attrs_for_scheduler_thread() const { + return composed_attrs_for_scheduler_thread_.get(); + } + + ComposedAttrMap* composed_attrs_for_main_thread() const { + return composed_attrs_for_main_thread_.get(); + } - LocalUserOpInferContext* op_infer_ctx_for_thread_a() const { - return op_infer_ctx_for_thread_a_.get(); + LocalUserOpInferContext* op_infer_ctx_for_scheduler_thread() const { + return op_infer_ctx_for_scheduler_thread_.get(); } - LocalUserOpInferContext* op_infer_ctx_for_thread_b() const { - return op_infer_ctx_for_thread_b_.get(); + LocalUserOpInferContext* op_infer_ctx_for_main_thread() const { + return op_infer_ctx_for_main_thread_.get(); } void set_need_check_mem_case(bool value) { need_check_mem_case_ = value; } @@ -326,13 +332,14 @@ class StatefulLocalOpKernel final { const user_op::InferTmpSizeFn& GetInferTmpSizeFn(const user_op::OpKernel* op_kernel) const; std::shared_ptr op_conf_; - std::unique_ptr composed_attrs_; + std::unique_ptr composed_attrs_for_scheduler_thread_; + std::unique_ptr composed_attrs_for_main_thread_; std::unique_ptr user_op_conf_; std::shared_ptr device_; std::unique_ptr reg_ctx_; std::unique_ptr create_ctx_; - std::unique_ptr op_infer_ctx_for_thread_a_; - std::unique_ptr op_infer_ctx_for_thread_b_; + std::unique_ptr op_infer_ctx_for_scheduler_thread_; + std::unique_ptr op_infer_ctx_for_main_thread_; std::unique_ptr compute_ctx_; std::shared_ptr input_arg_tuple_; std::shared_ptr output_arg_tuple_;
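
Note: the renames above (`composed_attrs_` to `composed_attrs_for_scheduler_thread_` / `composed_attrs_for_main_thread_`, plus the matching infer contexts) give each of the two threads that touch a `StatefulLocalOpKernel` its own mutable attribute overlay, so `ResetPrior()` on the main thread in `NaiveInterpret` can no longer race with the scheduler thread's `Infer()` path, and the old `ResetDynamicOpAttrs()` wrapper is dropped because callers now name the thread-specific map explicitly. Assuming, based on how it is used in this diff, that `ComposedAttrMap` resolves attributes by consulting a per-call prior map before the op's static base attributes, a simplified model looks like this (`ComposedAttrs` is a sketch, not the OneFlow class):

```cpp
#include <map>
#include <optional>
#include <string>

using AttrMap = std::map<std::string, double>;

// Simplified model of ComposedAttrMap: per-call attrs override the base attrs.
// Each thread that infers or launches the kernel owns one instance, so
// resetting the prior layer on one thread cannot race with lookups on the other.
class ComposedAttrs {
 public:
  explicit ComposedAttrs(AttrMap base) : base_(std::move(base)) {}

  // Swap in the dynamic, per-call attributes (what ResetPrior does in the diff).
  void ResetPrior(AttrMap prior) { prior_ = std::move(prior); }

  // Lookup order: per-call prior first, then the static op conf attributes.
  std::optional<double> GetAttr(const std::string& name) const {
    if (auto it = prior_.find(name); it != prior_.end()) { return it->second; }
    if (auto it = base_.find(name); it != base_.end()) { return it->second; }
    return std::nullopt;
  }

 private:
  AttrMap prior_;  // dynamic, per-call attributes
  AttrMap base_;   // static attributes from the op conf
};
```
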