Skip to content

Commit

Permalink
Refactor Memory Zone (#5072)
Browse files Browse the repository at this point in the history
* MemZoneId


Former-commit-id: 7550a12

* move mem zone id source code


Former-commit-id: 3859fc2

* revert


Former-commit-id: 5cf3ad7

* refine GetProxyNode using MemZoneId


Former-commit-id: fba035f

* refactor MemZoneId121


Former-commit-id: 0868a61

* replace using IDMgr interface


Former-commit-id: 98b5db9

* fix linkage

* rm useless comment

* replace IsGpuMemZone

* format

* rm deprecated mem zone api in IDMgr

* fix merge conflict error

* refine mem zone id to include node index

* revert added header

* direct init device_id

* address review

* address review

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
  • Loading branch information
leaves-zwx and oneflow-ci-bot committed Jun 21, 2021
1 parent 5303ee1 commit 50f32b6
Show file tree
Hide file tree
Showing 12 changed files with 234 additions and 121 deletions.
37 changes: 18 additions & 19 deletions oneflow/core/graph/boxing/slice_boxing_sub_task_graph_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ Maybe<SubTskGphBuilderStatus> SliceBoxingSubTskGphBuilder::Build(
UNIMPLEMENTED();
}
dst_node->Init(lbi, dst_slice, kSliceBoxingTaskModeCopy, src_node->machine_id(), thrd_id,
Global<IDMgr>::Get()->CpuMemZoneId());
GetNodeCPUMemZoneId(src_node->machine_id()));
dst_node->ConnectToSrcNodeWithSlice(src_node, NewEdge(), src_slice);
return dst_node;
};
Expand All @@ -184,7 +184,7 @@ Maybe<SubTskGphBuilderStatus> SliceBoxingSubTskGphBuilder::Build(
out_node->ConnectToSrcNodeWithSlice(in_node, NewEdge(), in_slice);
} else {
TaskNode* proxy_node = ctx->task_graph()->GetProxyNode(
in_node, lbi, out_node->machine_id(), Global<IDMgr>::Get()->CpuMemZoneId());
in_node, lbi, GetNodeCPUMemZoneId(out_node->machine_id()));
out_node->ConnectToSrcNodeWithSlice(proxy_node, NewEdge(), in_slice);
}
}
Expand Down Expand Up @@ -283,15 +283,15 @@ Maybe<SubTskGphBuilderStatus> SliceBoxingSubTskGphBuilder::Build(
#endif
}
local_concat_node->Init(lbi, concat_slice, kSliceBoxingTaskModeCopy, in_machine_id,
local_concat_thrd_id, Global<IDMgr>::Get()->CpuMemZoneId());
local_concat_thrd_id, GetNodeCPUMemZoneId(in_machine_id));
for (const int64_t in_id : in_parallel_ids) {
if (!in_id2intersection.at(in_id).IsEmpty()) {
local_concat_node->ConnectToSrcNodeWithSlice(in_nodes.at(in_id), NewEdge(),
in_slices.at(in_id));
}
}
TaskNode* local_add_proxy_node = ctx->task_graph()->GetProxyNode(
local_concat_node, lbi, out_node->machine_id(), Global<IDMgr>::Get()->CpuMemZoneId());
local_concat_node, lbi, GetNodeCPUMemZoneId(out_node->machine_id()));
out_node->ConnectToSrcNodeWithSlice(local_add_proxy_node, NewEdge(), concat_slice);
}
}
Expand Down Expand Up @@ -345,13 +345,12 @@ Maybe<SubTskGphBuilderStatus> SliceBoxingSubTskGphBuilder::Build(
#endif
}
local_add_node->Init(lbi, out_slice, kSliceBoxingTaskModeAdd, in_machine_id,
local_add_thrd_id, Global<IDMgr>::Get()->CpuMemZoneId());
local_add_thrd_id, GetNodeCPUMemZoneId(in_machine_id));
for (const int64_t in_id : in_parallel_ids) {
local_add_node->ConnectToSrcNodeWithSlice(in_nodes.at(in_id), NewEdge(), in_slice);
}
TaskNode* local_add_proxy_node =
ctx->task_graph()->GetProxyNode(local_add_node, lbi, out_node->machine_id(),
Global<IDMgr>::Get()->CpuMemZoneId());
TaskNode* local_add_proxy_node = ctx->task_graph()->GetProxyNode(
local_add_node, lbi, GetNodeCPUMemZoneId(out_node->machine_id()));
out_node->ConnectToSrcNodeWithSlice(local_add_proxy_node, NewEdge(), out_slice);
}
}
Expand Down Expand Up @@ -405,33 +404,33 @@ Maybe<SubTskGphBuilderStatus> SliceBoxingSubTskGphBuilder::Build(
const int64_t out_machine_id = machine_id7out_parallel_ids.first;
TaskNode* in_box_node = nullptr;
if (out_box_nodes.size() == 1) {
in_box_node = ctx->task_graph()->GetProxyNode(out_box_nodes.front(), lbi,
machine_id7out_parallel_ids.first,
Global<IDMgr>::Get()->CpuMemZoneId());
in_box_node = ctx->task_graph()->GetProxyNode(
out_box_nodes.front(), lbi, GetNodeCPUMemZoneId(machine_id7out_parallel_ids.first));
} else {
auto* add_node = ctx->task_graph()->NewNode<SliceBoxingTaskNode>();
add_node->Init(lbi, slice, kSliceBoxingTaskModeAdd, machine_id7out_parallel_ids.first,
Global<IDMgr>::Get()->PickCpuThrdIdEvenly(machine_id7out_parallel_ids.first),
Global<IDMgr>::Get()->CpuMemZoneId());
GetNodeCPUMemZoneId(machine_id7out_parallel_ids.first));
for (TaskNode* out_box_node : out_box_nodes) {
TaskNode* out_boxing_node_proxy = ctx->task_graph()->GetProxyNode(
out_box_node, lbi, out_machine_id, Global<IDMgr>::Get()->CpuMemZoneId());
out_box_node, lbi, GetNodeCPUMemZoneId(out_machine_id));
add_node->ConnectToSrcNodeWithSlice(out_boxing_node_proxy, NewEdge(), slice);
}
in_box_node = add_node;
}
for (const int64_t out_id : machine_id7out_parallel_ids.second) {
int64_t mem_zone_id;
if (out_pd.device_type() == DeviceType::kCPU) {
mem_zone_id = Global<IDMgr>::Get()->CpuMemZoneId();
(*out_nodes)[out_id] = ctx->task_graph()->GetProxyNode(
in_box_node, lbi, GetNodeCPUMemZoneId(out_machine_id));
} else if (out_pd.device_type() == DeviceType::kGPU) {
mem_zone_id =
Global<IDMgr>::Get()->GpuMemZoneId(CHECK_JUST(out_pd.DeviceId4ParallelId(out_id)));
int64_t dev_id = CHECK_JUST(out_pd.DeviceId4ParallelId(out_id));
(*out_nodes)[out_id] = ctx->task_graph()->GetProxyNode(
in_box_node, lbi,
MemZoneId{static_cast<MemZoneId::node_index_t>(out_machine_id), DeviceType::kGPU,
static_cast<MemZoneId::device_index_t>(dev_id)});
} else {
UNIMPLEMENTED();
}
(*out_nodes)[out_id] =
ctx->task_graph()->GetProxyNode(in_box_node, lbi, out_machine_id, mem_zone_id);
}
}
};
Expand Down
6 changes: 3 additions & 3 deletions oneflow/core/graph/copy_task_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,15 @@ class CopyHdTaskNode final : public CopyTaskNode {
void Init(CopyHdOpConf::Type, int64_t machine_id, int64_t dev_phy_id, const LogicalBlobId& lbi);

CopyHdOpConf::Type copy_type() const { return copy_type_; }
int64_t MemZoneId121() const override {
MemZoneId MemZoneId121() const override {
if (copy_type_ == CopyHdOpConf::H2D) {
return TaskNode::MemZoneId121();
} else if (copy_type_ == CopyHdOpConf::D2H) {
return Global<IDMgr>::Get()->CpuMemZoneId();
return GetNodeCPUMemZoneId(this->machine_id());
} else {
UNIMPLEMENTED();
return -1;
}
return kInvalidMemZoneId;
}

private:
Expand Down
32 changes: 13 additions & 19 deletions oneflow/core/graph/slice_boxing_task_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,17 @@ limitations under the License.
*/
#include "oneflow/core/framework/to_string.h"
#include "oneflow/core/graph/slice_boxing_task_node.h"
#include "oneflow/core/graph/id_serialization.h"

namespace oneflow {

void SliceBoxingTaskNode::Init(const LogicalBlobId& lbi, const TensorSliceView& out_slice,
const SliceBoxingTaskMode mode, int64_t machine_id, int64_t thrd_id,
int64_t mem_zone_id) {
MemZoneId&& mem_zone_id) {
out_slice_ = out_slice;
out_shape_ = out_slice.shape();
mode_ = mode;
mem_zone_id_ = mem_zone_id;
mem_zone_id_ = std::move(mem_zone_id);
set_machine_id(machine_id);
set_thrd_id(thrd_id);
set_lbi(lbi);
Expand All @@ -33,17 +34,8 @@ void SliceBoxingTaskNode::Init(const LogicalBlobId& lbi, const TensorSliceView&
void SliceBoxingTaskNode::Init(const LogicalBlobId& lbi, const TensorSliceView& out_slice,
const SliceBoxingTaskMode mode, int64_t machine_id,
int64_t thrd_id) {
IDMgr* global_id_mgr = Global<IDMgr>::Get();
DeviceType device_type = global_id_mgr->GetDeviceTypeFromThrdId(thrd_id);
int64_t mem_zone_id;
if (device_type == DeviceType::kCPU) {
mem_zone_id = global_id_mgr->CpuMemZoneId();
} else if (device_type == DeviceType::kGPU) {
mem_zone_id = global_id_mgr->GpuMemZoneId(global_id_mgr->GetGpuPhyIdFromThrdId(thrd_id));
} else {
UNIMPLEMENTED();
}
Init(lbi, out_slice, mode, machine_id, thrd_id, mem_zone_id);
StreamId stream_id = DeserializeStreamIdFromInt64(thrd_id);
Init(lbi, out_slice, mode, machine_id, thrd_id, MemZoneId(stream_id.device_id()));
}

void SliceBoxingTaskNode::ProduceAllRegstsAndBindEdges() {
Expand Down Expand Up @@ -126,17 +118,19 @@ OperatorConf SliceBoxingTaskNode::GetBoxingOpConf() {
}

void SliceBoxingTaskNode::InitProducedRegstMemCase(MemoryCase* mem_case) {
if (Global<IDMgr>::Get()->IsCpuMemZone(mem_zone_id_)) {
if (mem_zone_id_.device_type() == DeviceType::kCPU) {
HostMemory* host_mem = mem_case->mutable_host_mem();
if (device_type() == DeviceType::kGPU) {
host_mem->mutable_cuda_pinned_mem()->set_device_id(GpuPhyId());
StreamId stream_id = DeserializeStreamIdFromInt64(thrd_id());
if (stream_id.device_id().device_type() == DeviceType::kGPU) {
host_mem->mutable_cuda_pinned_mem()->set_device_id(stream_id.device_id().device_index());
}
} else if (Global<IDMgr>::Get()->IsGpuMemZone(mem_zone_id_)) {
mem_case->mutable_device_cuda_mem()->set_device_id(
Global<IDMgr>::Get()->GetGpuPhyIdFromMemZoneId(mem_zone_id_));
} else if (mem_zone_id_.device_type() == DeviceType::kGPU) {
mem_case->mutable_device_cuda_mem()->set_device_id(mem_zone_id_.device_index());
} else {
UNIMPLEMENTED();
}
}

MemZoneId SliceBoxingTaskNode::MemZoneId121() const { return mem_zone_id_; }

} // namespace oneflow
7 changes: 4 additions & 3 deletions oneflow/core/graph/slice_boxing_task_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License.

#include "oneflow/core/graph/transport_task_node.h"
#include "oneflow/core/register/tensor_slice_view.h"
#include "oneflow/core/memory/memory_zone.h"

namespace oneflow {

Expand All @@ -34,7 +35,7 @@ class SliceBoxingTaskNode final : public TransportTaskNode {
~SliceBoxingTaskNode() override = default;

void Init(const LogicalBlobId& lbi, const TensorSliceView& out_slice, SliceBoxingTaskMode mode,
int64_t machine_id, int64_t thrd_id, int64_t mem_zone_id);
int64_t machine_id, int64_t thrd_id, MemZoneId&& mem_zone_id);
void Init(const LogicalBlobId& lbi, const TensorSliceView& out_slice, SliceBoxingTaskMode mode,
int64_t machine_id, int64_t thrd_id);
void ProduceAllRegstsAndBindEdges() override;
Expand All @@ -49,14 +50,14 @@ class SliceBoxingTaskNode final : public TransportTaskNode {
void InferProducedDataRegstTimeShape() override;
OperatorConf GetBoxingOpConf();
void InitProducedRegstMemCase(MemoryCase*) override;
int64_t MemZoneId121() const override { return mem_zone_id_; }
MemZoneId MemZoneId121() const override;

HashMap<const TaskEdge*, TensorSliceView> in_data_edge2slice_;
std::vector<const TaskEdge*> ordered_in_data_edges_;
TensorSliceView out_slice_;
Shape out_shape_;
SliceBoxingTaskMode mode_ = kSliceBoxingTaskModeInvalid;
int64_t mem_zone_id_;
MemZoneId mem_zone_id_;
};

} // namespace oneflow
Expand Down
90 changes: 43 additions & 47 deletions oneflow/core/graph/task_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -464,69 +464,67 @@ TaskEdge* TaskGraph::NewTaskEdgeWithLbis(const std::vector<LogicalBlobId>& lbis)
}

TaskNode* TaskGraph::GetProxyNode(TaskNode* src_node, const LogicalBlobId& lbi,
int64_t dst_machine_id, int64_t dst_mem_zone_id) {
int64_t src_mem_zone_id = src_node->MemZoneId121();
const ProxyKey key(src_node, lbi, dst_machine_id, dst_mem_zone_id);
if (proxy2node.find(key) != proxy2node.cend()) {
return proxy2node.at(key);
const MemZoneId& dst_mem_zone_id) {
const auto& src_mem_zone_id = src_node->MemZoneId121();
const ProxyKey key(src_node, lbi, dst_mem_zone_id);
auto it = proxy2node.find(key);
if (it != proxy2node.cend()) {
// hit cache
return it->second;
} else {
if (dst_machine_id == src_node->machine_id() && dst_mem_zone_id == src_mem_zone_id) {
if (src_mem_zone_id == dst_mem_zone_id) {
// in the same memory zone
proxy2node[key] = src_node;
return src_node;
} else if (Global<IDMgr>::Get()->IsGpuMemZone(dst_mem_zone_id)) {
TaskNode* proxy_on_dst_host =
GetProxyNode(src_node, lbi, dst_machine_id, Global<IDMgr>::Get()->CpuMemZoneId());
CopyHdTaskNode* copy_task = NewNode<CopyHdTaskNode>();
copy_task->Init(CopyHdOpConf::H2D, proxy_on_dst_host->machine_id(),
Global<IDMgr>::Get()->GetGpuPhyIdFromMemZoneId(dst_mem_zone_id), lbi);
Connect<TaskNode>(proxy_on_dst_host, NewTaskEdgeWithLbi(lbi), copy_task);
proxy2node[key] = copy_task;
return copy_task;
} else if (Global<IDMgr>::Get()->IsCpuMemZone(dst_mem_zone_id)) {
if (src_node->machine_id() == dst_machine_id) {
if (Global<IDMgr>::Get()->IsGpuMemZone(src_mem_zone_id)) {
CopyHdTaskNode* copy_task = NewNode<CopyHdTaskNode>();
copy_task->Init(CopyHdOpConf::D2H, src_node->machine_id(),
Global<IDMgr>::Get()->GetGpuPhyIdFromMemZoneId(src_mem_zone_id), lbi);
Connect<TaskNode>(src_node, NewTaskEdgeWithLbi(lbi), copy_task);
proxy2node[key] = copy_task;
return copy_task;
} else {
UNIMPLEMENTED();
}
} else if (dst_mem_zone_id.device_type() == DeviceType::kCPU) {
if (src_mem_zone_id.node_index() == dst_mem_zone_id.node_index()) {
CHECK_EQ(src_mem_zone_id.device_type(), DeviceType::kGPU);
// on the same node, not on the same device
// src must be not on the cpu mem zone, copy d2h first
CopyHdTaskNode* copy_task = NewNode<CopyHdTaskNode>();
copy_task->Init(CopyHdOpConf::D2H, src_mem_zone_id.node_index(),
src_mem_zone_id.device_index(), lbi);
Connect<TaskNode>(src_node, NewTaskEdgeWithLbi(lbi), copy_task);
proxy2node[key] = copy_task;
return copy_task;
} else {
TaskNode* proxy_on_src_host = GetProxyNode(src_node, lbi, src_node->machine_id(),
Global<IDMgr>::Get()->CpuMemZoneId());
// not on the same node, need CopyCommNet from src to dst
// build src cpu proxy first
TaskNode* proxy_on_src_host =
GetProxyNode(src_node, lbi, GetNodeCPUMemZoneId(src_mem_zone_id.node_index()));
CopyCommNetTaskNode* copy_comm_net_task = NewNode<CopyCommNetTaskNode>();
copy_comm_net_task->Init(dst_machine_id, lbi);
copy_comm_net_task->Init(dst_mem_zone_id.node_index(), lbi);
Connect<TaskNode>(proxy_on_src_host, NewTaskEdgeWithLbi(lbi), copy_comm_net_task);
proxy2node[key] = copy_comm_net_task;
return copy_comm_net_task;
}
} else {
UNIMPLEMENTED();
CHECK_EQ(dst_mem_zone_id.device_type(), DeviceType::kGPU);
TaskNode* proxy_on_dst_host =
GetProxyNode(src_node, lbi, GetNodeCPUMemZoneId(dst_mem_zone_id.node_index()));
CopyHdTaskNode* copy_task = NewNode<CopyHdTaskNode>();
copy_task->Init(CopyHdOpConf::H2D, proxy_on_dst_host->machine_id(),
dst_mem_zone_id.device_index(), lbi);
Connect<TaskNode>(proxy_on_dst_host, NewTaskEdgeWithLbi(lbi), copy_task);
proxy2node[key] = copy_task;
return copy_task;
}
}
UNIMPLEMENTED();
return nullptr;
}

TaskNode* TaskGraph::GetProxyNode(TaskNode* src_node, const LogicalBlobId& lbi,
const ParallelDesc& dst_parallel_desc, int64_t dst_parallel_id) {
const int64_t dst_machine_id =
CHECK_JUST(dst_parallel_desc.MachineId4ParallelId(dst_parallel_id));
int64_t dst_mem_zone_id;
const IDMgr* id_mgr = Global<IDMgr>::Get();
if (dst_parallel_desc.device_type() == DeviceType::kCPU) {
dst_mem_zone_id = id_mgr->CpuMemZoneId();
} else if (dst_parallel_desc.device_type() == DeviceType::kGPU) {
const int64_t dst_dev_phy_id =
CHECK_JUST(dst_parallel_desc.DeviceId4ParallelId(dst_parallel_id));
dst_mem_zone_id = id_mgr->GpuMemZoneId(dst_dev_phy_id);
} else {
UNIMPLEMENTED();
}
return GetProxyNode(src_node, lbi, dst_machine_id, dst_mem_zone_id);
const int64_t dev_id = CHECK_JUST(dst_parallel_desc.DeviceId4ParallelId(dst_parallel_id));
DeviceType device_type = dst_parallel_desc.device_type();
auto device_index =
(device_type == DeviceType::kCPU ? DeviceId::kCPUDeviceIndex
: static_cast<DeviceId::device_index_t>(dev_id));
MemZoneId mem_zone_id{static_cast<MemZoneId::node_index_t>(dst_machine_id), device_type,
device_index};
return GetProxyNode(src_node, lbi, mem_zone_id);
}

void TaskGraph::ConnectCtrlEdges(const std::vector<CompTaskNode*>& src_task_nodes,
Expand Down Expand Up @@ -848,9 +846,7 @@ void TaskGraph::ConnectWithLbi(TaskNode* src_node, TaskNode* dst_node, const Log
}

void TaskGraph::BuildTaskPath(TaskNode* src_node, TaskNode* dst_node, const LogicalBlobId& lbi) {
int64_t dst_machine_id = dst_node->machine_id();
int64_t dst_mem_zone_id = dst_node->MemZoneId121();
TaskNode* proxy_node = GetProxyNode(src_node, lbi, dst_machine_id, dst_mem_zone_id);
TaskNode* proxy_node = GetProxyNode(src_node, lbi, dst_node->MemZoneId121());
ConnectWithLbi(proxy_node, dst_node, lbi);
}

Expand Down
16 changes: 8 additions & 8 deletions oneflow/core/graph/task_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ limitations under the License.
#include "oneflow/core/graph/copy_task_node.h"
#include "oneflow/core/register/op_blob_arg_info.h"
#include "oneflow/core/graph/boxing/boxing_logger.h"
#include "oneflow/core/memory/memory_zone.h"

namespace oneflow {

Expand Down Expand Up @@ -51,8 +52,8 @@ class TaskGraph final : public Graph<TaskNode, TaskEdge> {
void EnableInplaceMemSharing(const std::function<bool(const std::string&, const std::string&)>&
IsOpNameDataOrCtrlReachable);

TaskNode* GetProxyNode(TaskNode* src_node, const LogicalBlobId& lbi, int64_t dst_machine_id,
int64_t dst_mem_zone_id);
TaskNode* GetProxyNode(TaskNode* src_node, const LogicalBlobId& lbi,
const MemZoneId& dst_mem_zone_id);

TaskNode* GetProxyNode(TaskNode* src_node, const LogicalBlobId& lbi,
const ParallelDesc& dst_parallel_desc, int64_t dst_parallel_id);
Expand Down Expand Up @@ -104,21 +105,20 @@ class TaskGraph final : public Graph<TaskNode, TaskEdge> {
struct ProxyKey {
TaskNode* src_node;
LogicalBlobId lbi;
int64_t dst_machine_id;
int64_t dst_mem_zone_id;
MemZoneId dst_mem_zone_id;

ProxyKey(TaskNode* src, const LogicalBlobId& arg_lbi, int64_t arg_machine, int64_t arg_zone)
: src_node(src), lbi(arg_lbi), dst_machine_id(arg_machine), dst_mem_zone_id(arg_zone) {}
ProxyKey(TaskNode* src, const LogicalBlobId& arg_lbi, const MemZoneId& arg_mem_zone_id)
: src_node(src), lbi(arg_lbi), dst_mem_zone_id(arg_mem_zone_id) {}

bool operator==(const ProxyKey& other) const {
return src_node == other.src_node && lbi == other.lbi
&& dst_machine_id == other.dst_machine_id && dst_mem_zone_id == other.dst_mem_zone_id;
&& dst_mem_zone_id == other.dst_mem_zone_id;
}

struct Hasher {
inline size_t operator()(const ProxyKey& key) const {
return std::hash<TaskNode*>{}(key.src_node) ^ std::hash<LogicalBlobId>{}(key.lbi)
^ key.dst_machine_id ^ key.dst_mem_zone_id;
^ key.dst_mem_zone_id.hash();
}
};
};
Expand Down

0 comments on commit 50f32b6

Please sign in to comment.