Skip to content

Commit

Permalink
refine GetProxyNode using MemZoneId
Browse files (browse the repository at this point in the history)
  • Loading branch information
leaves-zwx committed May 28, 2021
1 parent 5cf3ad7 commit fba035f
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 36 deletions.
69 changes: 33 additions & 36 deletions oneflow/core/graph/task_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -462,69 +462,66 @@ TaskEdge* TaskGraph::NewTaskEdgeWithLbis(const std::vector<LogicalBlobId>& lbis)
}

TaskNode* TaskGraph::GetProxyNode(TaskNode* src_node, const LogicalBlobId& lbi,
int64_t dst_machine_id, int64_t dst_mem_zone_id) {
int64_t src_mem_zone_id = src_node->MemZoneId121();
const ProxyKey key(src_node, lbi, dst_machine_id, dst_mem_zone_id);
int64_t dst_machine_id, const MemZoneId& dst_mem_zone_id) {
const auto& src_mem_zone_id = src_node->MemZoneId121();
const ProxyKey key(src_node, lbi, dst_machine_id, EncodeMemZoneIdToInt64(dst_mem_zone_id));
if (proxy2node.find(key) != proxy2node.cend()) {
// hit cache
return proxy2node.at(key);
} else {
if (dst_machine_id == src_node->machine_id() && dst_mem_zone_id == src_mem_zone_id) {
// in the same memory zone
proxy2node[key] = src_node;
return src_node;
} else if (Global<IDMgr>::Get()->IsGpuMemZone(dst_mem_zone_id)) {
TaskNode* proxy_on_dst_host =
GetProxyNode(src_node, lbi, dst_machine_id, Global<IDMgr>::Get()->CpuMemZoneId());
CopyHdTaskNode* copy_task = NewNode<CopyHdTaskNode>();
copy_task->Init(CopyHdOpConf::H2D, proxy_on_dst_host->machine_id(),
Global<IDMgr>::Get()->GetGpuPhyIdFromMemZoneId(dst_mem_zone_id), lbi);
Connect<TaskNode>(proxy_on_dst_host, NewTaskEdgeWithLbi(lbi), copy_task);
proxy2node[key] = copy_task;
return copy_task;
} else if (Global<IDMgr>::Get()->IsCpuMemZone(dst_mem_zone_id)) {
} else if (dst_mem_zone_id.device_type() == DeviceType::kCPU) {
if (src_node->machine_id() == dst_machine_id) {
if (Global<IDMgr>::Get()->IsGpuMemZone(src_mem_zone_id)) {
CopyHdTaskNode* copy_task = NewNode<CopyHdTaskNode>();
copy_task->Init(CopyHdOpConf::D2H, src_node->machine_id(),
Global<IDMgr>::Get()->GetGpuPhyIdFromMemZoneId(src_mem_zone_id), lbi);
Connect<TaskNode>(src_node, NewTaskEdgeWithLbi(lbi), copy_task);
proxy2node[key] = copy_task;
return copy_task;
} else {
UNIMPLEMENTED();
}
// Copy D2H, only support gpu for now
CHECK_EQ(src_mem_zone_id.device_type(), DeviceType::kGPU);
CopyHdTaskNode* copy_task = NewNode<CopyHdTaskNode>();
copy_task->Init(CopyHdOpConf::D2H, src_node->machine_id(), src_mem_zone_id.device_index(),
lbi);
Connect<TaskNode>(src_node, NewTaskEdgeWithLbi(lbi), copy_task);
proxy2node[key] = copy_task;
return copy_task;
} else {
TaskNode* proxy_on_src_host = GetProxyNode(src_node, lbi, src_node->machine_id(),
Global<IDMgr>::Get()->CpuMemZoneId());
// not on the same node, CopyCommNet first
TaskNode* proxy_on_src_host =
GetProxyNode(src_node, lbi, src_node->machine_id(), kCPUMemZoneId);
CopyCommNetTaskNode* copy_comm_net_task = NewNode<CopyCommNetTaskNode>();
copy_comm_net_task->Init(dst_machine_id, lbi);
Connect<TaskNode>(proxy_on_src_host, NewTaskEdgeWithLbi(lbi), copy_comm_net_task);
proxy2node[key] = copy_comm_net_task;
return copy_comm_net_task;
}
} else {
UNIMPLEMENTED();
// Copy H2D, only support gpu for now
CHECK_EQ(dst_mem_zone_id.device_type(), DeviceType::kGPU);
TaskNode* proxy_on_dst_host = GetProxyNode(src_node, lbi, dst_machine_id, kCPUMemZoneId);
CopyHdTaskNode* copy_task = NewNode<CopyHdTaskNode>();
copy_task->Init(CopyHdOpConf::H2D, proxy_on_dst_host->machine_id(),
dst_mem_zone_id.device_index(), lbi);
Connect<TaskNode>(proxy_on_dst_host, NewTaskEdgeWithLbi(lbi), copy_task);
proxy2node[key] = copy_task;
return copy_task;
}
}
UNIMPLEMENTED();
return nullptr;
}

TaskNode* TaskGraph::GetProxyNode(TaskNode* src_node, const LogicalBlobId& lbi,
const ParallelDesc& dst_parallel_desc, int64_t dst_parallel_id) {
const int64_t dst_machine_id =
CHECK_JUST(dst_parallel_desc.MachineId4ParallelId(dst_parallel_id));
int64_t dst_mem_zone_id;
const IDMgr* id_mgr = Global<IDMgr>::Get();
if (dst_parallel_desc.device_type() == DeviceType::kCPU) {
dst_mem_zone_id = id_mgr->CpuMemZoneId();
} else if (dst_parallel_desc.device_type() == DeviceType::kGPU) {
const int64_t dst_dev_phy_id =
CHECK_JUST(dst_parallel_desc.DeviceId4ParallelId(dst_parallel_id));
dst_mem_zone_id = id_mgr->GpuMemZoneId(dst_dev_phy_id);
return GetProxyNode(src_node, lbi, dst_machine_id, kCPUMemZoneId);
} else {
UNIMPLEMENTED();
CHECK_EQ(dst_parallel_desc.device_type(), DeviceType::kGPU);
int64_t dev_id = CHECK_JUST(dst_parallel_desc.DeviceId4ParallelId(dst_parallel_id));
auto device_index = static_cast<MemZoneId::device_index_t>(dev_id);
MemZoneId mem_zone_id{dst_parallel_desc.device_type(), device_index};
return GetProxyNode(src_node, lbi, dst_machine_id, mem_zone_id);
}
return GetProxyNode(src_node, lbi, dst_machine_id, dst_mem_zone_id);
return nullptr;
}

void TaskGraph::ConnectCtrlEdges(const std::vector<CompTaskNode*>& src_task_nodes,
Expand Down
4 changes: 4 additions & 0 deletions oneflow/core/graph/task_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ limitations under the License.
#include "oneflow/core/graph/copy_task_node.h"
#include "oneflow/core/register/op_blob_arg_info.h"
#include "oneflow/core/graph/boxing/boxing_logger.h"
#include "oneflow/core/memory/memory_zone.h"

namespace oneflow {

Expand Down Expand Up @@ -54,6 +55,9 @@ class TaskGraph final : public Graph<TaskNode, TaskEdge> {
TaskNode* GetProxyNode(TaskNode* src_node, const LogicalBlobId& lbi, int64_t dst_machine_id,
int64_t dst_mem_zone_id);

TaskNode* GetProxyNode(TaskNode* src_node, const LogicalBlobId& lbi,
int64_t dst_machine_id, const MemZoneId& dst_mem_zone_id);

TaskNode* GetProxyNode(TaskNode* src_node, const LogicalBlobId& lbi,
const ParallelDesc& dst_parallel_desc, int64_t dst_parallel_id);

Expand Down

0 comments on commit fba035f

Please sign in to comment.