diff --git a/oneflow/core/graph/boxing/slice_boxing_sub_task_graph_builder.cpp b/oneflow/core/graph/boxing/slice_boxing_sub_task_graph_builder.cpp index 545b07f5732..f71308ac3c7 100644 --- a/oneflow/core/graph/boxing/slice_boxing_sub_task_graph_builder.cpp +++ b/oneflow/core/graph/boxing/slice_boxing_sub_task_graph_builder.cpp @@ -160,7 +160,7 @@ Maybe SliceBoxingSubTskGphBuilder::Build( UNIMPLEMENTED(); } dst_node->Init(lbi, dst_slice, kSliceBoxingTaskModeCopy, src_node->machine_id(), thrd_id, - Global::Get()->CpuMemZoneId()); + GetNodeCPUMemZoneId(src_node->machine_id())); dst_node->ConnectToSrcNodeWithSlice(src_node, NewEdge(), src_slice); return dst_node; }; @@ -184,7 +184,7 @@ Maybe SliceBoxingSubTskGphBuilder::Build( out_node->ConnectToSrcNodeWithSlice(in_node, NewEdge(), in_slice); } else { TaskNode* proxy_node = ctx->task_graph()->GetProxyNode( - in_node, lbi, out_node->machine_id(), Global::Get()->CpuMemZoneId()); + in_node, lbi, GetNodeCPUMemZoneId(out_node->machine_id())); out_node->ConnectToSrcNodeWithSlice(proxy_node, NewEdge(), in_slice); } } @@ -283,7 +283,7 @@ Maybe SliceBoxingSubTskGphBuilder::Build( #endif } local_concat_node->Init(lbi, concat_slice, kSliceBoxingTaskModeCopy, in_machine_id, - local_concat_thrd_id, Global::Get()->CpuMemZoneId()); + local_concat_thrd_id, GetNodeCPUMemZoneId(in_machine_id)); for (const int64_t in_id : in_parallel_ids) { if (!in_id2intersection.at(in_id).IsEmpty()) { local_concat_node->ConnectToSrcNodeWithSlice(in_nodes.at(in_id), NewEdge(), @@ -291,7 +291,7 @@ Maybe SliceBoxingSubTskGphBuilder::Build( } } TaskNode* local_add_proxy_node = ctx->task_graph()->GetProxyNode( - local_concat_node, lbi, out_node->machine_id(), Global::Get()->CpuMemZoneId()); + local_concat_node, lbi, GetNodeCPUMemZoneId(out_node->machine_id())); out_node->ConnectToSrcNodeWithSlice(local_add_proxy_node, NewEdge(), concat_slice); } } @@ -345,13 +345,12 @@ Maybe SliceBoxingSubTskGphBuilder::Build( #endif } local_add_node->Init(lbi, out_slice, kSliceBoxingTaskModeAdd, in_machine_id, - local_add_thrd_id, Global::Get()->CpuMemZoneId()); + local_add_thrd_id, GetNodeCPUMemZoneId(in_machine_id)); for (const int64_t in_id : in_parallel_ids) { local_add_node->ConnectToSrcNodeWithSlice(in_nodes.at(in_id), NewEdge(), in_slice); } - TaskNode* local_add_proxy_node = - ctx->task_graph()->GetProxyNode(local_add_node, lbi, out_node->machine_id(), - Global::Get()->CpuMemZoneId()); + TaskNode* local_add_proxy_node = ctx->task_graph()->GetProxyNode( + local_add_node, lbi, GetNodeCPUMemZoneId(out_node->machine_id())); out_node->ConnectToSrcNodeWithSlice(local_add_proxy_node, NewEdge(), out_slice); } } @@ -405,33 +404,33 @@ Maybe SliceBoxingSubTskGphBuilder::Build( const int64_t out_machine_id = machine_id7out_parallel_ids.first; TaskNode* in_box_node = nullptr; if (out_box_nodes.size() == 1) { - in_box_node = ctx->task_graph()->GetProxyNode(out_box_nodes.front(), lbi, - machine_id7out_parallel_ids.first, - Global::Get()->CpuMemZoneId()); + in_box_node = ctx->task_graph()->GetProxyNode( + out_box_nodes.front(), lbi, GetNodeCPUMemZoneId(machine_id7out_parallel_ids.first)); } else { auto* add_node = ctx->task_graph()->NewNode(); add_node->Init(lbi, slice, kSliceBoxingTaskModeAdd, machine_id7out_parallel_ids.first, Global::Get()->PickCpuThrdIdEvenly(machine_id7out_parallel_ids.first), - Global::Get()->CpuMemZoneId()); + GetNodeCPUMemZoneId(machine_id7out_parallel_ids.first)); for (TaskNode* out_box_node : out_box_nodes) { TaskNode* out_boxing_node_proxy = ctx->task_graph()->GetProxyNode( - out_box_node, lbi, out_machine_id, Global::Get()->CpuMemZoneId()); + out_box_node, lbi, GetNodeCPUMemZoneId(out_machine_id)); add_node->ConnectToSrcNodeWithSlice(out_boxing_node_proxy, NewEdge(), slice); } in_box_node = add_node; } for (const int64_t out_id : machine_id7out_parallel_ids.second) { - int64_t mem_zone_id; if (out_pd.device_type() == DeviceType::kCPU) { - mem_zone_id = Global::Get()->CpuMemZoneId(); + (*out_nodes)[out_id] = ctx->task_graph()->GetProxyNode( + in_box_node, lbi, GetNodeCPUMemZoneId(out_machine_id)); } else if (out_pd.device_type() == DeviceType::kGPU) { - mem_zone_id = - Global::Get()->GpuMemZoneId(CHECK_JUST(out_pd.DeviceId4ParallelId(out_id))); + int64_t dev_id = CHECK_JUST(out_pd.DeviceId4ParallelId(out_id)); + (*out_nodes)[out_id] = ctx->task_graph()->GetProxyNode( + in_box_node, lbi, + MemZoneId{static_cast(out_machine_id), DeviceType::kGPU, + static_cast(dev_id)}); } else { UNIMPLEMENTED(); } - (*out_nodes)[out_id] = - ctx->task_graph()->GetProxyNode(in_box_node, lbi, out_machine_id, mem_zone_id); } } }; diff --git a/oneflow/core/graph/copy_task_node.h b/oneflow/core/graph/copy_task_node.h index 2b0f5ac71bb..8a503e8e8ab 100644 --- a/oneflow/core/graph/copy_task_node.h +++ b/oneflow/core/graph/copy_task_node.h @@ -48,15 +48,15 @@ class CopyHdTaskNode final : public CopyTaskNode { void Init(CopyHdOpConf::Type, int64_t machine_id, int64_t dev_phy_id, const LogicalBlobId& lbi); CopyHdOpConf::Type copy_type() const { return copy_type_; } - int64_t MemZoneId121() const override { + MemZoneId MemZoneId121() const override { if (copy_type_ == CopyHdOpConf::H2D) { return TaskNode::MemZoneId121(); } else if (copy_type_ == CopyHdOpConf::D2H) { - return Global::Get()->CpuMemZoneId(); + return GetNodeCPUMemZoneId(this->machine_id()); } else { UNIMPLEMENTED(); - return -1; } + return kInvalidMemZoneId; } private: diff --git a/oneflow/core/graph/slice_boxing_task_node.cpp b/oneflow/core/graph/slice_boxing_task_node.cpp index 89ddc7b17f3..d50939ea121 100644 --- a/oneflow/core/graph/slice_boxing_task_node.cpp +++ b/oneflow/core/graph/slice_boxing_task_node.cpp @@ -15,16 +15,17 @@ limitations under the License. */ #include "oneflow/core/framework/to_string.h" #include "oneflow/core/graph/slice_boxing_task_node.h" +#include "oneflow/core/graph/id_serialization.h" namespace oneflow { void SliceBoxingTaskNode::Init(const LogicalBlobId& lbi, const TensorSliceView& out_slice, const SliceBoxingTaskMode mode, int64_t machine_id, int64_t thrd_id, - int64_t mem_zone_id) { + MemZoneId&& mem_zone_id) { out_slice_ = out_slice; out_shape_ = out_slice.shape(); mode_ = mode; - mem_zone_id_ = mem_zone_id; + mem_zone_id_ = std::move(mem_zone_id); set_machine_id(machine_id); set_thrd_id(thrd_id); set_lbi(lbi); @@ -33,17 +34,8 @@ void SliceBoxingTaskNode::Init(const LogicalBlobId& lbi, const TensorSliceView& void SliceBoxingTaskNode::Init(const LogicalBlobId& lbi, const TensorSliceView& out_slice, const SliceBoxingTaskMode mode, int64_t machine_id, int64_t thrd_id) { - IDMgr* global_id_mgr = Global::Get(); - DeviceType device_type = global_id_mgr->GetDeviceTypeFromThrdId(thrd_id); - int64_t mem_zone_id; - if (device_type == DeviceType::kCPU) { - mem_zone_id = global_id_mgr->CpuMemZoneId(); - } else if (device_type == DeviceType::kGPU) { - mem_zone_id = global_id_mgr->GpuMemZoneId(global_id_mgr->GetGpuPhyIdFromThrdId(thrd_id)); - } else { - UNIMPLEMENTED(); - } - Init(lbi, out_slice, mode, machine_id, thrd_id, mem_zone_id); + StreamId stream_id = DeserializeStreamIdFromInt64(thrd_id); + Init(lbi, out_slice, mode, machine_id, thrd_id, MemZoneId(stream_id.device_id())); } void SliceBoxingTaskNode::ProduceAllRegstsAndBindEdges() { @@ -126,17 +118,19 @@ OperatorConf SliceBoxingTaskNode::GetBoxingOpConf() { } void SliceBoxingTaskNode::InitProducedRegstMemCase(MemoryCase* mem_case) { - if (Global::Get()->IsCpuMemZone(mem_zone_id_)) { + if (mem_zone_id_.device_type() == DeviceType::kCPU) { HostMemory* host_mem = mem_case->mutable_host_mem(); - if (device_type() == DeviceType::kGPU) { - host_mem->mutable_cuda_pinned_mem()->set_device_id(GpuPhyId()); + StreamId stream_id = DeserializeStreamIdFromInt64(thrd_id()); + if (stream_id.device_id().device_type() == DeviceType::kGPU) { + host_mem->mutable_cuda_pinned_mem()->set_device_id(stream_id.device_id().device_index()); } - } else if (Global::Get()->IsGpuMemZone(mem_zone_id_)) { - mem_case->mutable_device_cuda_mem()->set_device_id( - Global::Get()->GetGpuPhyIdFromMemZoneId(mem_zone_id_)); + } else if (mem_zone_id_.device_type() == DeviceType::kGPU) { + mem_case->mutable_device_cuda_mem()->set_device_id(mem_zone_id_.device_index()); } else { UNIMPLEMENTED(); } } +MemZoneId SliceBoxingTaskNode::MemZoneId121() const { return mem_zone_id_; } + } // namespace oneflow diff --git a/oneflow/core/graph/slice_boxing_task_node.h b/oneflow/core/graph/slice_boxing_task_node.h index ff72380529a..9e0b92e33f3 100644 --- a/oneflow/core/graph/slice_boxing_task_node.h +++ b/oneflow/core/graph/slice_boxing_task_node.h @@ -18,6 +18,7 @@ limitations under the License. #include "oneflow/core/graph/transport_task_node.h" #include "oneflow/core/register/tensor_slice_view.h" +#include "oneflow/core/memory/memory_zone.h" namespace oneflow { @@ -34,7 +35,7 @@ class SliceBoxingTaskNode final : public TransportTaskNode { ~SliceBoxingTaskNode() override = default; void Init(const LogicalBlobId& lbi, const TensorSliceView& out_slice, SliceBoxingTaskMode mode, - int64_t machine_id, int64_t thrd_id, int64_t mem_zone_id); + int64_t machine_id, int64_t thrd_id, MemZoneId&& mem_zone_id); void Init(const LogicalBlobId& lbi, const TensorSliceView& out_slice, SliceBoxingTaskMode mode, int64_t machine_id, int64_t thrd_id); void ProduceAllRegstsAndBindEdges() override; @@ -49,14 +50,14 @@ class SliceBoxingTaskNode final : public TransportTaskNode { void InferProducedDataRegstTimeShape() override; OperatorConf GetBoxingOpConf(); void InitProducedRegstMemCase(MemoryCase*) override; - int64_t MemZoneId121() const override { return mem_zone_id_; } + MemZoneId MemZoneId121() const override; HashMap in_data_edge2slice_; std::vector ordered_in_data_edges_; TensorSliceView out_slice_; Shape out_shape_; SliceBoxingTaskMode mode_ = kSliceBoxingTaskModeInvalid; - int64_t mem_zone_id_; + MemZoneId mem_zone_id_; }; } // namespace oneflow diff --git a/oneflow/core/graph/task_graph.cpp b/oneflow/core/graph/task_graph.cpp index ab02c0b0a3b..74475415694 100644 --- a/oneflow/core/graph/task_graph.cpp +++ b/oneflow/core/graph/task_graph.cpp @@ -464,50 +464,52 @@ TaskEdge* TaskGraph::NewTaskEdgeWithLbis(const std::vector& lbis) } TaskNode* TaskGraph::GetProxyNode(TaskNode* src_node, const LogicalBlobId& lbi, - int64_t dst_machine_id, int64_t dst_mem_zone_id) { - int64_t src_mem_zone_id = src_node->MemZoneId121(); - const ProxyKey key(src_node, lbi, dst_machine_id, dst_mem_zone_id); - if (proxy2node.find(key) != proxy2node.cend()) { - return proxy2node.at(key); + const MemZoneId& dst_mem_zone_id) { + const auto& src_mem_zone_id = src_node->MemZoneId121(); + const ProxyKey key(src_node, lbi, dst_mem_zone_id); + auto it = proxy2node.find(key); + if (it != proxy2node.cend()) { + // hit cache + return it->second; } else { - if (dst_machine_id == src_node->machine_id() && dst_mem_zone_id == src_mem_zone_id) { + if (src_mem_zone_id == dst_mem_zone_id) { + // in the same memory zone proxy2node[key] = src_node; return src_node; - } else if (Global::Get()->IsGpuMemZone(dst_mem_zone_id)) { - TaskNode* proxy_on_dst_host = - GetProxyNode(src_node, lbi, dst_machine_id, Global::Get()->CpuMemZoneId()); - CopyHdTaskNode* copy_task = NewNode(); - copy_task->Init(CopyHdOpConf::H2D, proxy_on_dst_host->machine_id(), - Global::Get()->GetGpuPhyIdFromMemZoneId(dst_mem_zone_id), lbi); - Connect(proxy_on_dst_host, NewTaskEdgeWithLbi(lbi), copy_task); - proxy2node[key] = copy_task; - return copy_task; - } else if (Global::Get()->IsCpuMemZone(dst_mem_zone_id)) { - if (src_node->machine_id() == dst_machine_id) { - if (Global::Get()->IsGpuMemZone(src_mem_zone_id)) { - CopyHdTaskNode* copy_task = NewNode(); - copy_task->Init(CopyHdOpConf::D2H, src_node->machine_id(), - Global::Get()->GetGpuPhyIdFromMemZoneId(src_mem_zone_id), lbi); - Connect(src_node, NewTaskEdgeWithLbi(lbi), copy_task); - proxy2node[key] = copy_task; - return copy_task; - } else { - UNIMPLEMENTED(); - } + } else if (dst_mem_zone_id.device_type() == DeviceType::kCPU) { + if (src_mem_zone_id.node_index() == dst_mem_zone_id.node_index()) { + CHECK_EQ(src_mem_zone_id.device_type(), DeviceType::kGPU); + // on the same node, not on the same device + // src must be not on the cpu mem zone, copy d2h first + CopyHdTaskNode* copy_task = NewNode(); + copy_task->Init(CopyHdOpConf::D2H, src_mem_zone_id.node_index(), + src_mem_zone_id.device_index(), lbi); + Connect(src_node, NewTaskEdgeWithLbi(lbi), copy_task); + proxy2node[key] = copy_task; + return copy_task; } else { - TaskNode* proxy_on_src_host = GetProxyNode(src_node, lbi, src_node->machine_id(), - Global::Get()->CpuMemZoneId()); + // not on the same node, need CopyCommNet from src to dst + // build src cpu proxy first + TaskNode* proxy_on_src_host = + GetProxyNode(src_node, lbi, GetNodeCPUMemZoneId(src_mem_zone_id.node_index())); CopyCommNetTaskNode* copy_comm_net_task = NewNode(); - copy_comm_net_task->Init(dst_machine_id, lbi); + copy_comm_net_task->Init(dst_mem_zone_id.node_index(), lbi); Connect(proxy_on_src_host, NewTaskEdgeWithLbi(lbi), copy_comm_net_task); proxy2node[key] = copy_comm_net_task; return copy_comm_net_task; } } else { - UNIMPLEMENTED(); + CHECK_EQ(dst_mem_zone_id.device_type(), DeviceType::kGPU); + TaskNode* proxy_on_dst_host = + GetProxyNode(src_node, lbi, GetNodeCPUMemZoneId(dst_mem_zone_id.node_index())); + CopyHdTaskNode* copy_task = NewNode(); + copy_task->Init(CopyHdOpConf::H2D, proxy_on_dst_host->machine_id(), + dst_mem_zone_id.device_index(), lbi); + Connect(proxy_on_dst_host, NewTaskEdgeWithLbi(lbi), copy_task); + proxy2node[key] = copy_task; + return copy_task; } } - UNIMPLEMENTED(); return nullptr; } @@ -515,18 +517,14 @@ TaskNode* TaskGraph::GetProxyNode(TaskNode* src_node, const LogicalBlobId& lbi, const ParallelDesc& dst_parallel_desc, int64_t dst_parallel_id) { const int64_t dst_machine_id = CHECK_JUST(dst_parallel_desc.MachineId4ParallelId(dst_parallel_id)); - int64_t dst_mem_zone_id; - const IDMgr* id_mgr = Global::Get(); - if (dst_parallel_desc.device_type() == DeviceType::kCPU) { - dst_mem_zone_id = id_mgr->CpuMemZoneId(); - } else if (dst_parallel_desc.device_type() == DeviceType::kGPU) { - const int64_t dst_dev_phy_id = - CHECK_JUST(dst_parallel_desc.DeviceId4ParallelId(dst_parallel_id)); - dst_mem_zone_id = id_mgr->GpuMemZoneId(dst_dev_phy_id); - } else { - UNIMPLEMENTED(); - } - return GetProxyNode(src_node, lbi, dst_machine_id, dst_mem_zone_id); + const int64_t dev_id = CHECK_JUST(dst_parallel_desc.DeviceId4ParallelId(dst_parallel_id)); + DeviceType device_type = dst_parallel_desc.device_type(); + auto device_index = + (device_type == DeviceType::kCPU ? DeviceId::kCPUDeviceIndex + : static_cast(dev_id)); + MemZoneId mem_zone_id{static_cast(dst_machine_id), device_type, + device_index}; + return GetProxyNode(src_node, lbi, mem_zone_id); } void TaskGraph::ConnectCtrlEdges(const std::vector& src_task_nodes, @@ -848,9 +846,7 @@ void TaskGraph::ConnectWithLbi(TaskNode* src_node, TaskNode* dst_node, const Log } void TaskGraph::BuildTaskPath(TaskNode* src_node, TaskNode* dst_node, const LogicalBlobId& lbi) { - int64_t dst_machine_id = dst_node->machine_id(); - int64_t dst_mem_zone_id = dst_node->MemZoneId121(); - TaskNode* proxy_node = GetProxyNode(src_node, lbi, dst_machine_id, dst_mem_zone_id); + TaskNode* proxy_node = GetProxyNode(src_node, lbi, dst_node->MemZoneId121()); ConnectWithLbi(proxy_node, dst_node, lbi); } diff --git a/oneflow/core/graph/task_graph.h b/oneflow/core/graph/task_graph.h index 11da3468c4d..71593a834f1 100644 --- a/oneflow/core/graph/task_graph.h +++ b/oneflow/core/graph/task_graph.h @@ -24,6 +24,7 @@ limitations under the License. #include "oneflow/core/graph/copy_task_node.h" #include "oneflow/core/register/op_blob_arg_info.h" #include "oneflow/core/graph/boxing/boxing_logger.h" +#include "oneflow/core/memory/memory_zone.h" namespace oneflow { @@ -51,8 +52,8 @@ class TaskGraph final : public Graph { void EnableInplaceMemSharing(const std::function& IsOpNameDataOrCtrlReachable); - TaskNode* GetProxyNode(TaskNode* src_node, const LogicalBlobId& lbi, int64_t dst_machine_id, - int64_t dst_mem_zone_id); + TaskNode* GetProxyNode(TaskNode* src_node, const LogicalBlobId& lbi, + const MemZoneId& dst_mem_zone_id); TaskNode* GetProxyNode(TaskNode* src_node, const LogicalBlobId& lbi, const ParallelDesc& dst_parallel_desc, int64_t dst_parallel_id); @@ -104,21 +105,20 @@ class TaskGraph final : public Graph { struct ProxyKey { TaskNode* src_node; LogicalBlobId lbi; - int64_t dst_machine_id; - int64_t dst_mem_zone_id; + MemZoneId dst_mem_zone_id; - ProxyKey(TaskNode* src, const LogicalBlobId& arg_lbi, int64_t arg_machine, int64_t arg_zone) - : src_node(src), lbi(arg_lbi), dst_machine_id(arg_machine), dst_mem_zone_id(arg_zone) {} + ProxyKey(TaskNode* src, const LogicalBlobId& arg_lbi, const MemZoneId& arg_mem_zone_id) + : src_node(src), lbi(arg_lbi), dst_mem_zone_id(arg_mem_zone_id) {} bool operator==(const ProxyKey& other) const { return src_node == other.src_node && lbi == other.lbi - && dst_machine_id == other.dst_machine_id && dst_mem_zone_id == other.dst_mem_zone_id; + && dst_mem_zone_id == other.dst_mem_zone_id; } struct Hasher { inline size_t operator()(const ProxyKey& key) const { return std::hash{}(key.src_node) ^ std::hash{}(key.lbi) - ^ key.dst_machine_id ^ key.dst_mem_zone_id; + ^ key.dst_mem_zone_id.hash(); } }; }; diff --git a/oneflow/core/graph/task_node.cpp b/oneflow/core/graph/task_node.cpp index 98a1f8d5548..97c181e0eee 100644 --- a/oneflow/core/graph/task_node.cpp +++ b/oneflow/core/graph/task_node.cpp @@ -233,13 +233,9 @@ void TaskNode::ToProto(TaskProto* task_proto) const { } } -int64_t TaskNode::MemZoneId121() const { - const IDMgr* id_mgr = Global::Get(); - if (device_type() == DeviceType::kCPU) { - return id_mgr->CpuMemZoneId(); - } else { - return id_mgr->GpuMemZoneId(id_mgr->GetGpuPhyIdFromThrdId(thrd_id_)); - } +MemZoneId TaskNode::MemZoneId121() const { + StreamId stream_id = DeserializeStreamIdFromInt64(thrd_id_); + return MemZoneId{stream_id.device_id()}; } bool TaskNode::BuildCtrlRegstDescIfNeed(TaskNode* dst_node, std::string* name) { diff --git a/oneflow/core/graph/task_node.h b/oneflow/core/graph/task_node.h index 4512f62ccae..33a3958d785 100644 --- a/oneflow/core/graph/task_node.h +++ b/oneflow/core/graph/task_node.h @@ -21,6 +21,7 @@ limitations under the License. #include "oneflow/core/job/task.pb.h" #include "oneflow/core/operator/operator.h" #include "oneflow/core/common/auto_registration_factory.h" +#include "oneflow/core/memory/memory_zone.h" namespace std { @@ -97,7 +98,7 @@ class TaskNode : public Node { virtual void ToProto(TaskProto*) const; virtual bool IsIndependent() const { return false; } void BindEdgeWithProducedRegst(TaskEdge*, const std::string& name); - virtual int64_t MemZoneId121() const; + virtual MemZoneId MemZoneId121() const; bool BuildCtrlRegstDescIfNeed(TaskNode* dst_node, std::string* name); RegstDesc* BuildCtrlRegstDesc(TaskNode* dst_node); RegstDesc* BuildCtrlRegstDesc(TaskNode* dst_node, std::string* name); diff --git a/oneflow/core/job/id_manager.h b/oneflow/core/job/id_manager.h index 51356fb33bf..cfb665dc05d 100644 --- a/oneflow/core/job/id_manager.h +++ b/oneflow/core/job/id_manager.h @@ -34,16 +34,6 @@ class IDMgr final { int64_t NewMemBlockId() { return mem_block_id_count_++; } int64_t NewChunkId() { return chunk_id_count_++; } - // MemZoneId - int64_t CpuMemZoneId() const { return gpu_device_num_; } - bool IsCpuMemZone(int64_t mem_zone_id) const { return mem_zone_id == CpuMemZoneId(); } - bool IsGpuMemZone(int64_t mem_zone_id) const { return mem_zone_id < gpu_device_num_; } - int64_t GpuMemZoneId(int64_t dev_phy_id) const { return dev_phy_id; } - int64_t GetGpuPhyIdFromMemZoneId(int64_t mem_zone_id) const { - CHECK_LT(mem_zone_id, gpu_device_num_); - return mem_zone_id; - } - // GetFromThrdId DeviceType GetDeviceTypeFromThrdId(int64_t thrd_id) const; int64_t GetGpuPhyIdFromThrdId(int64_t thrd_id) const; diff --git a/oneflow/core/job/improver.cpp b/oneflow/core/job/improver.cpp index 2467b5b74fb..cc5d6bf3843 100644 --- a/oneflow/core/job/improver.cpp +++ b/oneflow/core/job/improver.cpp @@ -388,15 +388,15 @@ Maybe Improver::CheckAllZoneNotOOM( const uint64_t calc = CalcMemoryConsumed(regst_descs, PathDurations4RegstDescId, PathIIScales4RegstDescId, ii); const uint64_t available = AvailableMemSize(machine_id, mem_zone_id); - const auto* id_mgr = Global::Get(); if (Global::Get()->enable_dry_run()) { + MemZoneId mem_zone = DecodeMemZoneIdFromInt64(mem_zone_id); LOG(ERROR) << "machine_id: " << machine_id << ", mem_zone_id: " << mem_zone_id - << ", is_gpu: " << (id_mgr->IsGpuMemZone(mem_zone_id) ? "yes" : "no") + << ", is_gpu: " << (mem_zone.device_type() == DeviceType::kGPU ? "yes" : "no") << ", CalcMemoryConsumed: " << calc; } if (calc >= available) { - const std::string device_tag = *JUST(DeviceTag4DeviceType( - id_mgr->IsGpuMemZone(mem_zone_id) ? DeviceType::kGPU : DeviceType::kCPU)); + MemZoneId mem_zone = DecodeMemZoneIdFromInt64(mem_zone_id); + const std::string device_tag = *JUST(DeviceTag4DeviceType(mem_zone.device_type())); return Error::MemoryZoneOutOfMemoryError(machine_id, mem_zone_id, calc, available, device_tag) << "OOM detected at compile time. "; diff --git a/oneflow/core/memory/memory_zone.cpp b/oneflow/core/memory/memory_zone.cpp new file mode 100644 index 00000000000..3a6f921f957 --- /dev/null +++ b/oneflow/core/memory/memory_zone.cpp @@ -0,0 +1,56 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/core/memory/memory_zone.h" +#include "oneflow/core/common/device_type.pb.h" + +namespace oneflow { + +namespace { + +constexpr size_t kMemZoneIdDeviceTypeShift = MemZoneId::kDeviceIndexBits; +constexpr size_t kMemZoneIdNodeIndexShift = kMemZoneIdDeviceTypeShift + MemZoneId::kDeviceTypeBits; + +constexpr int64_t kMemZoneIdNodeIndexInt64Mask = ((int64_t{1} << MemZoneId::kNodeIndexBits) - 1) + << kMemZoneIdNodeIndexShift; +constexpr int64_t kMemZoneIdDeviceTypeInt64Mask = ((int64_t{1} << MemZoneId::kDeviceTypeBits) - 1) + << kMemZoneIdDeviceTypeShift; +constexpr int64_t kMemZoneIdDeviceIndexInt64Mask = (int64_t{1} << MemZoneId::kDeviceIndexBits) - 1; + +} // namespace + +const MemZoneId kInvalidMemZoneId = MemZoneId{0, DeviceType::kInvalidDevice, 0}; + +MemZoneId GetNodeCPUMemZoneId(MemZoneId::node_index_t node_index) { + return MemZoneId{node_index, DeviceType::kCPU, MemZoneId::kCPUDeviceIndex}; +} + +int64_t EncodeMemZoneIdToInt64(const MemZoneId& mem_zone_id) { + int64_t id = static_cast(mem_zone_id.device_index()); + id |= static_cast(mem_zone_id.device_type()) << kMemZoneIdDeviceTypeShift; + id |= static_cast(mem_zone_id.node_index()) << kMemZoneIdNodeIndexShift; + return id; +} + +MemZoneId DecodeMemZoneIdFromInt64(int64_t mem_zone_id) { + int64_t node_index = (mem_zone_id & kMemZoneIdNodeIndexInt64Mask) >> kMemZoneIdNodeIndexShift; + int64_t device_type = (mem_zone_id & kMemZoneIdDeviceTypeInt64Mask) >> kMemZoneIdDeviceTypeShift; + int64_t device_index = mem_zone_id & kMemZoneIdDeviceIndexInt64Mask; + return MemZoneId(static_cast(node_index), + static_cast(device_type), + static_cast(device_index)); +} + +} // namespace oneflow diff --git a/oneflow/core/memory/memory_zone.h b/oneflow/core/memory/memory_zone.h new file mode 100644 index 00000000000..1da824faefb --- /dev/null +++ b/oneflow/core/memory/memory_zone.h @@ -0,0 +1,80 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#ifndef ONEFLOW_CORE_MEMORY_MEMORY_ZONE_H_ +#define ONEFLOW_CORE_MEMORY_MEMORY_ZONE_H_ + +#include "oneflow/core/common/util.h" +#include "oneflow/core/common/id_util.h" +#include "oneflow/core/common/device_type.pb.h" + +namespace oneflow { + +class MemZoneId { + public: + using node_index_t = DeviceId::rank_t; + using device_index_t = DeviceId::device_index_t; + + constexpr static size_t kNodeIndexBits = DeviceId::kRankBits; + constexpr static size_t kDeviceTypeBits = DeviceId::kDeviceTypeBits; + constexpr static size_t kDeviceIndexBits = DeviceId::kDeviceIndexBits; + + constexpr static size_t kMaxDeviceTypeVal = DeviceId::kMaxDeviceTypeVal; + constexpr static device_index_t kMaxDeviceIndex = DeviceId::kMaxDeviceIndex; + constexpr static device_index_t kCPUDeviceIndex = DeviceId::kCPUDeviceIndex; + + MemZoneId() : device_id_(0, DeviceType::kInvalidDevice, 0) {} + MemZoneId(const DeviceId& device_id) : device_id_(device_id) {} + MemZoneId(DeviceId&& device_id) : device_id_(std::move(device_id)) {} + + MemZoneId(node_index_t node_index, DeviceType device_type, device_index_t device_index) + : device_id_(node_index, device_type, device_index) { + CHECK_LE(static_cast(device_type), kMaxDeviceTypeVal); + CHECK_LE(device_index, kMaxDeviceIndex); + } + + const DeviceId& device_id() const { return device_id_; } + node_index_t node_index() const { return device_id_.rank(); } + DeviceType device_type() const { return device_id_.device_type(); } + device_index_t device_index() const { return device_id_.device_index(); } + + bool operator==(const MemZoneId& rhs) const { return device_id_ == rhs.device_id_; } + bool operator!=(const MemZoneId& rhs) const { return !(*this == rhs); } + + size_t hash() const { return device_id_.hash(); } + + private: + DeviceId device_id_; +}; + +int64_t EncodeMemZoneIdToInt64(const MemZoneId&); +MemZoneId DecodeMemZoneIdFromInt64(int64_t); + +MemZoneId GetNodeCPUMemZoneId(MemZoneId::node_index_t node_index); + +extern const MemZoneId kInvalidMemZoneId; + +} // namespace oneflow + +namespace std { + +template<> +struct hash { + size_t operator()(const oneflow::MemZoneId& mem_zone_id) const { return mem_zone_id.hash(); } +}; + +} // namespace std + +#endif // ONEFLOW_CORE_MEMORY_MEMORY_ZONE_H_