Skip to content

Commit

Permalink
refactor MemZoneId121
Browse files Browse the repository at this point in the history
  • Loading branch information
leaves-zwx committed May 31, 2021
1 parent fba035f commit 0868a61
Show file tree
Hide file tree
Showing 8 changed files with 23 additions and 21 deletions.
6 changes: 3 additions & 3 deletions oneflow/core/graph/copy_task_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,15 @@ class CopyHdTaskNode final : public CopyTaskNode {
void Init(CopyHdOpConf::Type, int64_t machine_id, int64_t dev_phy_id, const LogicalBlobId& lbi);

CopyHdOpConf::Type copy_type() const { return copy_type_; }
int64_t MemZoneId121() const override {
MemZoneId MemZoneId121() const override {
if (copy_type_ == CopyHdOpConf::H2D) {
return TaskNode::MemZoneId121();
} else if (copy_type_ == CopyHdOpConf::D2H) {
return Global<IDMgr>::Get()->CpuMemZoneId();
return kCPUMemZoneId;
} else {
UNIMPLEMENTED();
return -1;
}
return kInvalidMemZoneId;
}

private:
Expand Down
19 changes: 8 additions & 11 deletions oneflow/core/graph/slice_boxing_task_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License.
*/
#include "oneflow/core/framework/to_string.h"
#include "oneflow/core/graph/slice_boxing_task_node.h"
#include "oneflow/core/graph/id_serialization.h"

namespace oneflow {

Expand All @@ -33,17 +34,9 @@ void SliceBoxingTaskNode::Init(const LogicalBlobId& lbi, const TensorSliceView&
void SliceBoxingTaskNode::Init(const LogicalBlobId& lbi, const TensorSliceView& out_slice,
const SliceBoxingTaskMode mode, int64_t machine_id,
int64_t thrd_id) {
IDMgr* global_id_mgr = Global<IDMgr>::Get();
DeviceType device_type = global_id_mgr->GetDeviceTypeFromThrdId(thrd_id);
int64_t mem_zone_id;
if (device_type == DeviceType::kCPU) {
mem_zone_id = global_id_mgr->CpuMemZoneId();
} else if (device_type == DeviceType::kGPU) {
mem_zone_id = global_id_mgr->GpuMemZoneId(global_id_mgr->GetGpuPhyIdFromThrdId(thrd_id));
} else {
UNIMPLEMENTED();
}
Init(lbi, out_slice, mode, machine_id, thrd_id, mem_zone_id);
StreamId stream_id = DeserializeStreamIdFromInt64(thrd_id);
MemZoneId mem_zone_id{stream_id.device_id().device_type(), stream_id.device_id().device_index()};
Init(lbi, out_slice, mode, machine_id, thrd_id, EncodeMemZoneIdToInt64(mem_zone_id));
}

void SliceBoxingTaskNode::ProduceAllRegstsAndBindEdges() {
Expand Down Expand Up @@ -139,4 +132,8 @@ void SliceBoxingTaskNode::InitProducedRegstMemCase(MemoryCase* mem_case) {
}
}

MemZoneId SliceBoxingTaskNode::MemZoneId121() const {
return DecodeMemZoneIdFromInt64(mem_zone_id_);
}

} // namespace oneflow
2 changes: 1 addition & 1 deletion oneflow/core/graph/slice_boxing_task_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class SliceBoxingTaskNode final : public TransportTaskNode {
void InferProducedDataRegstTimeShape() override;
OperatorConf GetBoxingOpConf();
void InitProducedRegstMemCase(MemoryCase*) override;
int64_t MemZoneId121() const override { return mem_zone_id_; }
MemZoneId MemZoneId121() const override;

HashMap<const TaskEdge*, TensorSliceView> in_data_edge2slice_;
std::vector<const TaskEdge*> ordered_in_data_edges_;
Expand Down
2 changes: 1 addition & 1 deletion oneflow/core/graph/task_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -844,7 +844,7 @@ void TaskGraph::ConnectWithLbi(TaskNode* src_node, TaskNode* dst_node, const Log

void TaskGraph::BuildTaskPath(TaskNode* src_node, TaskNode* dst_node, const LogicalBlobId& lbi) {
int64_t dst_machine_id = dst_node->machine_id();
int64_t dst_mem_zone_id = dst_node->MemZoneId121();
auto dst_mem_zone_id = dst_node->MemZoneId121();
TaskNode* proxy_node = GetProxyNode(src_node, lbi, dst_machine_id, dst_mem_zone_id);
ConnectWithLbi(proxy_node, dst_node, lbi);
}
Expand Down
8 changes: 4 additions & 4 deletions oneflow/core/graph/task_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,12 +233,12 @@ void TaskNode::ToProto(TaskProto* task_proto) const {
}
}

int64_t TaskNode::MemZoneId121() const {
const IDMgr* id_mgr = Global<IDMgr>::Get();
MemZoneId TaskNode::MemZoneId121() const {
if (device_type() == DeviceType::kCPU) {
return id_mgr->CpuMemZoneId();
return kCPUMemZoneId;
} else {
return id_mgr->GpuMemZoneId(id_mgr->GetGpuPhyIdFromThrdId(thrd_id_));
StreamId stream_id = DeserializeStreamIdFromInt64(thrd_id_);
return MemZoneId{stream_id.device_id().device_type(), stream_id.device_id().device_index()};
}
}

Expand Down
3 changes: 2 additions & 1 deletion oneflow/core/graph/task_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ limitations under the License.
#include "oneflow/core/job/task.pb.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/common/auto_registration_factory.h"
#include "oneflow/core/memory/memory_zone.h"

namespace std {

Expand Down Expand Up @@ -97,7 +98,7 @@ class TaskNode : public Node<TaskNode, TaskEdge> {
virtual void ToProto(TaskProto*) const;
virtual bool IsIndependent() const { return false; }
void BindEdgeWithProducedRegst(TaskEdge*, const std::string& name);
virtual int64_t MemZoneId121() const;
virtual MemZoneId MemZoneId121() const;
bool BuildCtrlRegstDescIfNeed(TaskNode* dst_node, std::string* name);
RegstDesc* BuildCtrlRegstDesc(TaskNode* dst_node);
RegstDesc* BuildCtrlRegstDesc(TaskNode* dst_node, std::string* name);
Expand Down
3 changes: 3 additions & 0 deletions oneflow/core/memory/memory_zone.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "oneflow/core/memory/memory_zone.h"
#include "oneflow/core/common/device_type.pb.h"
// #include <climits>

namespace oneflow {
Expand All @@ -25,6 +26,8 @@ constexpr int64_t kMemZoneIdDeviceIndexInt64Mask = (int64_t{1} << MemZoneId::kDe

const MemZoneId kCPUMemZoneId = MemZoneId{DeviceType::kCPU, MemZoneId::kCPUDeviceIndex};

const MemZoneId kInvalidMemZoneId = MemZoneId{DeviceType::kInvalidDevice, 0};

} // namespace

int64_t EncodeMemZoneIdToInt64(const MemZoneId& mem_zone_id) {
Expand Down
1 change: 1 addition & 0 deletions oneflow/core/memory/memory_zone.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ int64_t EncodeMemZoneIdToInt64(const MemZoneId&);
MemZoneId DecodeMemZoneIdFromInt64(int64_t);

extern const MemZoneId kCPUMemZoneId;
extern const MemZoneId kInvalidMemZoneId;

} // namespace oneflow

Expand Down

0 comments on commit 0868a61

Please sign in to comment.