Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Memory Zone #5072

Merged
merged 22 commits into from
Jun 21, 2021
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 23 additions & 21 deletions oneflow/core/graph/boxing/slice_boxing_sub_task_graph_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ Maybe<SubTskGphBuilderStatus> SliceBoxingSubTskGphBuilder::Build(
UNIMPLEMENTED();
}
dst_node->Init(lbi, dst_slice, kSliceBoxingTaskModeCopy, src_node->machine_id(), thrd_id,
Global<IDMgr>::Get()->CpuMemZoneId());
EncodeMemZoneIdToInt64(GetNodeCPUMemZoneId(src_node->machine_id())));
dst_node->ConnectToSrcNodeWithSlice(src_node, NewEdge(), src_slice);
return dst_node;
};
Expand All @@ -184,7 +184,7 @@ Maybe<SubTskGphBuilderStatus> SliceBoxingSubTskGphBuilder::Build(
out_node->ConnectToSrcNodeWithSlice(in_node, NewEdge(), in_slice);
} else {
TaskNode* proxy_node = ctx->task_graph()->GetProxyNode(
in_node, lbi, out_node->machine_id(), Global<IDMgr>::Get()->CpuMemZoneId());
in_node, lbi, GetNodeCPUMemZoneId(out_node->machine_id()));
out_node->ConnectToSrcNodeWithSlice(proxy_node, NewEdge(), in_slice);
}
}
Expand Down Expand Up @@ -283,15 +283,16 @@ Maybe<SubTskGphBuilderStatus> SliceBoxingSubTskGphBuilder::Build(
#endif
}
local_concat_node->Init(lbi, concat_slice, kSliceBoxingTaskModeCopy, in_machine_id,
local_concat_thrd_id, Global<IDMgr>::Get()->CpuMemZoneId());
local_concat_thrd_id,
EncodeMemZoneIdToInt64(GetNodeCPUMemZoneId(in_machine_id)));
for (const int64_t in_id : in_parallel_ids) {
if (!in_id2intersection.at(in_id).IsEmpty()) {
local_concat_node->ConnectToSrcNodeWithSlice(in_nodes.at(in_id), NewEdge(),
in_slices.at(in_id));
}
}
TaskNode* local_add_proxy_node = ctx->task_graph()->GetProxyNode(
local_concat_node, lbi, out_node->machine_id(), Global<IDMgr>::Get()->CpuMemZoneId());
local_concat_node, lbi, GetNodeCPUMemZoneId(out_node->machine_id()));
out_node->ConnectToSrcNodeWithSlice(local_add_proxy_node, NewEdge(), concat_slice);
}
}
Expand Down Expand Up @@ -345,13 +346,13 @@ Maybe<SubTskGphBuilderStatus> SliceBoxingSubTskGphBuilder::Build(
#endif
}
local_add_node->Init(lbi, out_slice, kSliceBoxingTaskModeAdd, in_machine_id,
local_add_thrd_id, Global<IDMgr>::Get()->CpuMemZoneId());
local_add_thrd_id,
EncodeMemZoneIdToInt64(GetNodeCPUMemZoneId(in_machine_id)));
for (const int64_t in_id : in_parallel_ids) {
local_add_node->ConnectToSrcNodeWithSlice(in_nodes.at(in_id), NewEdge(), in_slice);
}
TaskNode* local_add_proxy_node =
ctx->task_graph()->GetProxyNode(local_add_node, lbi, out_node->machine_id(),
Global<IDMgr>::Get()->CpuMemZoneId());
TaskNode* local_add_proxy_node = ctx->task_graph()->GetProxyNode(
local_add_node, lbi, GetNodeCPUMemZoneId(out_node->machine_id()));
out_node->ConnectToSrcNodeWithSlice(local_add_proxy_node, NewEdge(), out_slice);
}
}
Expand Down Expand Up @@ -405,33 +406,34 @@ Maybe<SubTskGphBuilderStatus> SliceBoxingSubTskGphBuilder::Build(
const int64_t out_machine_id = machine_id7out_parallel_ids.first;
TaskNode* in_box_node = nullptr;
if (out_box_nodes.size() == 1) {
in_box_node = ctx->task_graph()->GetProxyNode(out_box_nodes.front(), lbi,
machine_id7out_parallel_ids.first,
Global<IDMgr>::Get()->CpuMemZoneId());
in_box_node = ctx->task_graph()->GetProxyNode(
out_box_nodes.front(), lbi, GetNodeCPUMemZoneId(machine_id7out_parallel_ids.first));
} else {
auto* add_node = ctx->task_graph()->NewNode<SliceBoxingTaskNode>();
add_node->Init(lbi, slice, kSliceBoxingTaskModeAdd, machine_id7out_parallel_ids.first,
Global<IDMgr>::Get()->PickCpuThrdIdEvenly(machine_id7out_parallel_ids.first),
Global<IDMgr>::Get()->CpuMemZoneId());
add_node->Init(
lbi, slice, kSliceBoxingTaskModeAdd, machine_id7out_parallel_ids.first,
Global<IDMgr>::Get()->PickCpuThrdIdEvenly(machine_id7out_parallel_ids.first),
EncodeMemZoneIdToInt64(GetNodeCPUMemZoneId(machine_id7out_parallel_ids.first)));
for (TaskNode* out_box_node : out_box_nodes) {
TaskNode* out_boxing_node_proxy = ctx->task_graph()->GetProxyNode(
out_box_node, lbi, out_machine_id, Global<IDMgr>::Get()->CpuMemZoneId());
out_box_node, lbi, GetNodeCPUMemZoneId(out_machine_id));
add_node->ConnectToSrcNodeWithSlice(out_boxing_node_proxy, NewEdge(), slice);
}
in_box_node = add_node;
}
for (const int64_t out_id : machine_id7out_parallel_ids.second) {
int64_t mem_zone_id;
if (out_pd.device_type() == DeviceType::kCPU) {
mem_zone_id = Global<IDMgr>::Get()->CpuMemZoneId();
(*out_nodes)[out_id] = ctx->task_graph()->GetProxyNode(
in_box_node, lbi, GetNodeCPUMemZoneId(out_machine_id));
} else if (out_pd.device_type() == DeviceType::kGPU) {
mem_zone_id =
Global<IDMgr>::Get()->GpuMemZoneId(CHECK_JUST(out_pd.DeviceId4ParallelId(out_id)));
int64_t dev_id = CHECK_JUST(out_pd.DeviceId4ParallelId(out_id));
(*out_nodes)[out_id] = ctx->task_graph()->GetProxyNode(
in_box_node, lbi,
MemZoneId{static_cast<MemZoneId::node_index_t>(out_machine_id), DeviceType::kGPU,
static_cast<MemZoneId::device_index_t>(dev_id)});
} else {
UNIMPLEMENTED();
}
(*out_nodes)[out_id] =
ctx->task_graph()->GetProxyNode(in_box_node, lbi, out_machine_id, mem_zone_id);
}
}
};
Expand Down
6 changes: 3 additions & 3 deletions oneflow/core/graph/copy_task_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,15 @@ class CopyHdTaskNode final : public CopyTaskNode {
void Init(CopyHdOpConf::Type, int64_t machine_id, int64_t dev_phy_id, const LogicalBlobId& lbi);

CopyHdOpConf::Type copy_type() const { return copy_type_; }
int64_t MemZoneId121() const override {
MemZoneId MemZoneId121() const override {
if (copy_type_ == CopyHdOpConf::H2D) {
return TaskNode::MemZoneId121();
} else if (copy_type_ == CopyHdOpConf::D2H) {
return Global<IDMgr>::Get()->CpuMemZoneId();
return GetNodeCPUMemZoneId(this->machine_id());
} else {
UNIMPLEMENTED();
return -1;
}
return kInvalidMemZoneId;
}

private:
Expand Down
32 changes: 15 additions & 17 deletions oneflow/core/graph/slice_boxing_task_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License.
*/
#include "oneflow/core/framework/to_string.h"
#include "oneflow/core/graph/slice_boxing_task_node.h"
#include "oneflow/core/graph/id_serialization.h"

namespace oneflow {

Expand All @@ -33,17 +34,9 @@ void SliceBoxingTaskNode::Init(const LogicalBlobId& lbi, const TensorSliceView&
void SliceBoxingTaskNode::Init(const LogicalBlobId& lbi, const TensorSliceView& out_slice,
const SliceBoxingTaskMode mode, int64_t machine_id,
int64_t thrd_id) {
IDMgr* global_id_mgr = Global<IDMgr>::Get();
DeviceType device_type = global_id_mgr->GetDeviceTypeFromThrdId(thrd_id);
int64_t mem_zone_id;
if (device_type == DeviceType::kCPU) {
mem_zone_id = global_id_mgr->CpuMemZoneId();
} else if (device_type == DeviceType::kGPU) {
mem_zone_id = global_id_mgr->GpuMemZoneId(global_id_mgr->GetGpuPhyIdFromThrdId(thrd_id));
} else {
UNIMPLEMENTED();
}
Init(lbi, out_slice, mode, machine_id, thrd_id, mem_zone_id);
StreamId stream_id = DeserializeStreamIdFromInt64(thrd_id);
MemZoneId mem_zone_id{stream_id.device_id()};
Init(lbi, out_slice, mode, machine_id, thrd_id, EncodeMemZoneIdToInt64(mem_zone_id));
}

void SliceBoxingTaskNode::ProduceAllRegstsAndBindEdges() {
Expand Down Expand Up @@ -126,17 +119,22 @@ OperatorConf SliceBoxingTaskNode::GetBoxingOpConf() {
}

void SliceBoxingTaskNode::InitProducedRegstMemCase(MemoryCase* mem_case) {
if (Global<IDMgr>::Get()->IsCpuMemZone(mem_zone_id_)) {
auto mem_zone_id = DecodeMemZoneIdFromInt64(mem_zone_id_);
if (mem_zone_id.device_type() == DeviceType::kCPU) {
HostMemory* host_mem = mem_case->mutable_host_mem();
if (device_type() == DeviceType::kGPU) {
host_mem->mutable_cuda_pinned_mem()->set_device_id(GpuPhyId());
StreamId stream_id = DeserializeStreamIdFromInt64(thrd_id());
if (stream_id.device_id().device_type() == DeviceType::kGPU) {
host_mem->mutable_cuda_pinned_mem()->set_device_id(stream_id.device_id().device_index());
}
} else if (Global<IDMgr>::Get()->IsGpuMemZone(mem_zone_id_)) {
mem_case->mutable_device_cuda_mem()->set_device_id(
Global<IDMgr>::Get()->GetGpuPhyIdFromMemZoneId(mem_zone_id_));
} else if (mem_zone_id.device_type() == DeviceType::kGPU) {
mem_case->mutable_device_cuda_mem()->set_device_id(mem_zone_id.device_index());
} else {
UNIMPLEMENTED();
}
}

MemZoneId SliceBoxingTaskNode::MemZoneId121() const {
return DecodeMemZoneIdFromInt64(mem_zone_id_);
leaves-zwx marked this conversation as resolved.
Show resolved Hide resolved
}

} // namespace oneflow
2 changes: 1 addition & 1 deletion oneflow/core/graph/slice_boxing_task_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class SliceBoxingTaskNode final : public TransportTaskNode {
void InferProducedDataRegstTimeShape() override;
OperatorConf GetBoxingOpConf();
void InitProducedRegstMemCase(MemoryCase*) override;
int64_t MemZoneId121() const override { return mem_zone_id_; }
MemZoneId MemZoneId121() const override;

HashMap<const TaskEdge*, TensorSliceView> in_data_edge2slice_;
std::vector<const TaskEdge*> ordered_in_data_edges_;
Expand Down
84 changes: 39 additions & 45 deletions oneflow/core/graph/task_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -464,69 +464,65 @@ TaskEdge* TaskGraph::NewTaskEdgeWithLbis(const std::vector<LogicalBlobId>& lbis)
}

TaskNode* TaskGraph::GetProxyNode(TaskNode* src_node, const LogicalBlobId& lbi,
int64_t dst_machine_id, int64_t dst_mem_zone_id) {
int64_t src_mem_zone_id = src_node->MemZoneId121();
const ProxyKey key(src_node, lbi, dst_machine_id, dst_mem_zone_id);
const MemZoneId& dst_mem_zone_id) {
const auto& src_mem_zone_id = src_node->MemZoneId121();
const ProxyKey key(src_node, lbi, dst_mem_zone_id);
if (proxy2node.find(key) != proxy2node.cend()) {
// hit cache
leaves-zwx marked this conversation as resolved.
Show resolved Hide resolved
return proxy2node.at(key);
} else {
if (dst_machine_id == src_node->machine_id() && dst_mem_zone_id == src_mem_zone_id) {
if (src_mem_zone_id == dst_mem_zone_id) {
// in the same memory zone
proxy2node[key] = src_node;
return src_node;
} else if (Global<IDMgr>::Get()->IsGpuMemZone(dst_mem_zone_id)) {
TaskNode* proxy_on_dst_host =
GetProxyNode(src_node, lbi, dst_machine_id, Global<IDMgr>::Get()->CpuMemZoneId());
CopyHdTaskNode* copy_task = NewNode<CopyHdTaskNode>();
copy_task->Init(CopyHdOpConf::H2D, proxy_on_dst_host->machine_id(),
Global<IDMgr>::Get()->GetGpuPhyIdFromMemZoneId(dst_mem_zone_id), lbi);
Connect<TaskNode>(proxy_on_dst_host, NewTaskEdgeWithLbi(lbi), copy_task);
proxy2node[key] = copy_task;
return copy_task;
} else if (Global<IDMgr>::Get()->IsCpuMemZone(dst_mem_zone_id)) {
if (src_node->machine_id() == dst_machine_id) {
if (Global<IDMgr>::Get()->IsGpuMemZone(src_mem_zone_id)) {
CopyHdTaskNode* copy_task = NewNode<CopyHdTaskNode>();
copy_task->Init(CopyHdOpConf::D2H, src_node->machine_id(),
Global<IDMgr>::Get()->GetGpuPhyIdFromMemZoneId(src_mem_zone_id), lbi);
Connect<TaskNode>(src_node, NewTaskEdgeWithLbi(lbi), copy_task);
proxy2node[key] = copy_task;
return copy_task;
} else {
UNIMPLEMENTED();
}
} else if (dst_mem_zone_id.device_type() == DeviceType::kCPU) {
if (src_mem_zone_id.node_index() == dst_mem_zone_id.node_index()) {
// on the same node, not on the same device
// src must be not on the cpu mem zone, copy d2h first
CopyHdTaskNode* copy_task = NewNode<CopyHdTaskNode>();
leaves-zwx marked this conversation as resolved.
Show resolved Hide resolved
copy_task->Init(CopyHdOpConf::D2H, src_mem_zone_id.node_index(),
src_mem_zone_id.device_index(), lbi);
Connect<TaskNode>(src_node, NewTaskEdgeWithLbi(lbi), copy_task);
proxy2node[key] = copy_task;
return copy_task;
} else {
TaskNode* proxy_on_src_host = GetProxyNode(src_node, lbi, src_node->machine_id(),
Global<IDMgr>::Get()->CpuMemZoneId());
// not on the same node, need CopyCommNet from src to dst
// build src cpu proxy first
TaskNode* proxy_on_src_host =
GetProxyNode(src_node, lbi, GetNodeCPUMemZoneId(src_mem_zone_id.node_index()));
CopyCommNetTaskNode* copy_comm_net_task = NewNode<CopyCommNetTaskNode>();
copy_comm_net_task->Init(dst_machine_id, lbi);
copy_comm_net_task->Init(dst_mem_zone_id.node_index(), lbi);
Connect<TaskNode>(proxy_on_src_host, NewTaskEdgeWithLbi(lbi), copy_comm_net_task);
proxy2node[key] = copy_comm_net_task;
return copy_comm_net_task;
}
} else {
UNIMPLEMENTED();
CHECK_EQ(dst_mem_zone_id.device_type(), DeviceType::kGPU);
TaskNode* proxy_on_dst_host =
GetProxyNode(src_node, lbi, GetNodeCPUMemZoneId(dst_mem_zone_id.node_index()));
CopyHdTaskNode* copy_task = NewNode<CopyHdTaskNode>();
copy_task->Init(CopyHdOpConf::H2D, proxy_on_dst_host->machine_id(),
dst_mem_zone_id.device_index(), lbi);
Connect<TaskNode>(proxy_on_dst_host, NewTaskEdgeWithLbi(lbi), copy_task);
proxy2node[key] = copy_task;
return copy_task;
}
}
UNIMPLEMENTED();
return nullptr;
}

TaskNode* TaskGraph::GetProxyNode(TaskNode* src_node, const LogicalBlobId& lbi,
const ParallelDesc& dst_parallel_desc, int64_t dst_parallel_id) {
const int64_t dst_machine_id =
CHECK_JUST(dst_parallel_desc.MachineId4ParallelId(dst_parallel_id));
int64_t dst_mem_zone_id;
const IDMgr* id_mgr = Global<IDMgr>::Get();
if (dst_parallel_desc.device_type() == DeviceType::kCPU) {
dst_mem_zone_id = id_mgr->CpuMemZoneId();
} else if (dst_parallel_desc.device_type() == DeviceType::kGPU) {
const int64_t dst_dev_phy_id =
CHECK_JUST(dst_parallel_desc.DeviceId4ParallelId(dst_parallel_id));
dst_mem_zone_id = id_mgr->GpuMemZoneId(dst_dev_phy_id);
} else {
UNIMPLEMENTED();
}
return GetProxyNode(src_node, lbi, dst_machine_id, dst_mem_zone_id);
const int64_t dev_id = CHECK_JUST(dst_parallel_desc.DeviceId4ParallelId(dst_parallel_id));
DeviceType device_type = dst_parallel_desc.device_type();
auto device_index =
(device_type == DeviceType::kCPU ? DeviceId::kCPUDeviceIndex
: static_cast<DeviceId::device_index_t>(dev_id));
MemZoneId mem_zone_id{static_cast<MemZoneId::node_index_t>(dst_machine_id), device_type,
device_index};
return GetProxyNode(src_node, lbi, mem_zone_id);
}

void TaskGraph::ConnectCtrlEdges(const std::vector<CompTaskNode*>& src_task_nodes,
Expand Down Expand Up @@ -848,9 +844,7 @@ void TaskGraph::ConnectWithLbi(TaskNode* src_node, TaskNode* dst_node, const Log
}

void TaskGraph::BuildTaskPath(TaskNode* src_node, TaskNode* dst_node, const LogicalBlobId& lbi) {
int64_t dst_machine_id = dst_node->machine_id();
int64_t dst_mem_zone_id = dst_node->MemZoneId121();
TaskNode* proxy_node = GetProxyNode(src_node, lbi, dst_machine_id, dst_mem_zone_id);
TaskNode* proxy_node = GetProxyNode(src_node, lbi, dst_node->MemZoneId121());
ConnectWithLbi(proxy_node, dst_node, lbi);
}

Expand Down
17 changes: 8 additions & 9 deletions oneflow/core/graph/task_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ limitations under the License.
#include "oneflow/core/graph/copy_task_node.h"
#include "oneflow/core/register/op_blob_arg_info.h"
#include "oneflow/core/graph/boxing/boxing_logger.h"
#include "oneflow/core/memory/memory_zone.h"

namespace oneflow {

Expand Down Expand Up @@ -51,8 +52,8 @@ class TaskGraph final : public Graph<TaskNode, TaskEdge> {
void EnableInplaceMemSharing(const std::function<bool(const std::string&, const std::string&)>&
IsOpNameDataOrCtrlReachable);

TaskNode* GetProxyNode(TaskNode* src_node, const LogicalBlobId& lbi, int64_t dst_machine_id,
int64_t dst_mem_zone_id);
TaskNode* GetProxyNode(TaskNode* src_node, const LogicalBlobId& lbi,
const MemZoneId& dst_mem_zone_id);

TaskNode* GetProxyNode(TaskNode* src_node, const LogicalBlobId& lbi,
const ParallelDesc& dst_parallel_desc, int64_t dst_parallel_id);
Expand Down Expand Up @@ -104,21 +105,19 @@ class TaskGraph final : public Graph<TaskNode, TaskEdge> {
struct ProxyKey {
TaskNode* src_node;
LogicalBlobId lbi;
int64_t dst_machine_id;
int64_t dst_mem_zone_id;
MemZoneId mem_zone_id;
leaves-zwx marked this conversation as resolved.
Show resolved Hide resolved

ProxyKey(TaskNode* src, const LogicalBlobId& arg_lbi, int64_t arg_machine, int64_t arg_zone)
: src_node(src), lbi(arg_lbi), dst_machine_id(arg_machine), dst_mem_zone_id(arg_zone) {}
ProxyKey(TaskNode* src, const LogicalBlobId& arg_lbi, const MemZoneId& arg_mem_zone_id)
: src_node(src), lbi(arg_lbi), mem_zone_id(arg_mem_zone_id) {}

bool operator==(const ProxyKey& other) const {
return src_node == other.src_node && lbi == other.lbi
&& dst_machine_id == other.dst_machine_id && dst_mem_zone_id == other.dst_mem_zone_id;
return src_node == other.src_node && lbi == other.lbi && mem_zone_id == other.mem_zone_id;
}

struct Hasher {
inline size_t operator()(const ProxyKey& key) const {
return std::hash<TaskNode*>{}(key.src_node) ^ std::hash<LogicalBlobId>{}(key.lbi)
^ key.dst_machine_id ^ key.dst_mem_zone_id;
^ key.mem_zone_id.hash();
}
};
};
Expand Down
10 changes: 3 additions & 7 deletions oneflow/core/graph/task_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,13 +233,9 @@ void TaskNode::ToProto(TaskProto* task_proto) const {
}
}

int64_t TaskNode::MemZoneId121() const {
const IDMgr* id_mgr = Global<IDMgr>::Get();
if (device_type() == DeviceType::kCPU) {
return id_mgr->CpuMemZoneId();
} else {
return id_mgr->GpuMemZoneId(id_mgr->GetGpuPhyIdFromThrdId(thrd_id_));
}
MemZoneId TaskNode::MemZoneId121() const {
StreamId stream_id = DeserializeStreamIdFromInt64(thrd_id_);
return MemZoneId{stream_id.device_id()};
}

bool TaskNode::BuildCtrlRegstDescIfNeed(TaskNode* dst_node, std::string* name) {
Expand Down
3 changes: 2 additions & 1 deletion oneflow/core/graph/task_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ limitations under the License.
#include "oneflow/core/job/task.pb.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/common/auto_registration_factory.h"
#include "oneflow/core/memory/memory_zone.h"

namespace std {

Expand Down Expand Up @@ -97,7 +98,7 @@ class TaskNode : public Node<TaskNode, TaskEdge> {
virtual void ToProto(TaskProto*) const;
virtual bool IsIndependent() const { return false; }
void BindEdgeWithProducedRegst(TaskEdge*, const std::string& name);
virtual int64_t MemZoneId121() const;
virtual MemZoneId MemZoneId121() const;
bool BuildCtrlRegstDescIfNeed(TaskNode* dst_node, std::string* name);
RegstDesc* BuildCtrlRegstDesc(TaskNode* dst_node);
RegstDesc* BuildCtrlRegstDesc(TaskNode* dst_node, std::string* name);
Expand Down
Loading