From 3859fc2a0fcda2fb23e57e886a0e3f1c0833d111 Mon Sep 17 00:00:00 2001 From: leaves-zwx Date: Fri, 28 May 2021 16:27:10 +0800 Subject: [PATCH] move mem zone id source code --- oneflow/core/common/id_util.h | 41 ----------------- oneflow/core/graph/id_serialization.cpp | 24 ---------- oneflow/core/graph/id_serialization.h | 3 -- oneflow/core/memory/memory_zone.cpp | 43 +++++++++++++++++ oneflow/core/memory/memory_zone.h | 61 +++++++++++++++++++++++++ 5 files changed, 104 insertions(+), 68 deletions(-) create mode 100644 oneflow/core/memory/memory_zone.cpp create mode 100644 oneflow/core/memory/memory_zone.h diff --git a/oneflow/core/common/id_util.h b/oneflow/core/common/id_util.h index 55236a818a2..ace96383c2c 100644 --- a/oneflow/core/common/id_util.h +++ b/oneflow/core/common/id_util.h @@ -130,42 +130,6 @@ class TaskId { task_index_t task_index_; }; -class MemZoneId { - public: - using device_index_t = uint32_t; - - constexpr static device_index_t kCPUDeviceIndex = 0; - constexpr static size_t kDeviceTypeBits = 5; - constexpr static size_t kDeviceIndexBits = 7; - constexpr static size_t kMaxDeviceTypeVal = (size_t{1} << kDeviceTypeBits) - size_t{1}; - constexpr static device_index_t kMaxDeviceIndex = - (device_index_t{1} << kDeviceIndexBits) - device_index_t{1}; - - MemZoneId(DeviceType device_type, device_index_t device_index) - : device_type_(device_type), device_index_(device_index) { - CHECK_LE(static_cast(device_type), kMaxDeviceTypeVal); - CHECK_LE(device_index, kMaxDeviceIndex); - } - - DeviceType device_type() const { return device_type_; } - device_index_t device_index() const { return device_index_; } - - bool operator==(const MemZoneId& rhs) const { - return device_type_ == rhs.device_type_ && device_index_ == rhs.device_index_; - } - bool operator!=(const MemZoneId& rhs) const { return !(*this == rhs); } - - size_t hash() const { - size_t hash = std::hash{}(static_cast(device_type_)); - HashCombine(&hash, std::hash{}(device_index_)); - return hash; - } - - private: - DeviceType device_type_; - device_index_t device_index_; -}; - } // namespace oneflow namespace std { @@ -185,11 +149,6 @@ struct hash { size_t operator()(const oneflow::TaskId& task_id) const { return task_id.hash(); } }; -template<> -struct hash { - size_t operator()(const oneflow::MemZoneId& mem_zone_id) const { return mem_zone_id.hash(); } -}; - } // namespace std #endif // ONEFLOW_CORE_COMMON_ID_UTIL_H_ diff --git a/oneflow/core/graph/id_serialization.cpp b/oneflow/core/graph/id_serialization.cpp index c161de17737..1b6fb4d249a 100644 --- a/oneflow/core/graph/id_serialization.cpp +++ b/oneflow/core/graph/id_serialization.cpp @@ -119,28 +119,4 @@ TaskId DeserializeTaskIdFromInt64(int64_t task_id_val) { return TaskId{stream_id, static_cast(task_index)}; } -namespace mem_zone_id_const { - -constexpr size_t kMemZoneIdDeviceTypeShift = MemZoneId::kDeviceIndexBits; -constexpr int64_t kMemZoneIdDeviceTypeInt64Mask = ((int64_t{1} << MemZoneId::kDeviceTypeBits) - 1) - << kMemZoneIdDeviceTypeShift; -constexpr int64_t kMemZoneIdDeviceIndexInt64Mask = (int64_t{1} << MemZoneId::kDeviceIndexBits) - 1; - -} // namespace mem_zone_id_const - -int64_t EncodeMemZoneIdToInt64(const MemZoneId& mem_zone_id) { - int64_t id = static_cast(mem_zone_id.device_index()); - id |= static_cast(mem_zone_id.device_type()) - << mem_zone_id_const::kMemZoneIdDeviceTypeShift; - return id; -} - -MemZoneId DecodeMemZoneIdFromInt64(int64_t mem_zone_id) { - int64_t device_type = (mem_zone_id & mem_zone_id_const::kMemZoneIdDeviceTypeInt64Mask) - >> mem_zone_id_const::kMemZoneIdDeviceTypeShift; - int64_t device_index = mem_zone_id & mem_zone_id_const::kMemZoneIdDeviceIndexInt64Mask; - return MemZoneId(static_cast(device_type), - static_cast(device_index)); -} - } // namespace oneflow diff --git a/oneflow/core/graph/id_serialization.h b/oneflow/core/graph/id_serialization.h index 5a3f75ebf2e..e0188f633dd 100644 --- a/oneflow/core/graph/id_serialization.h +++ b/oneflow/core/graph/id_serialization.h @@ -26,9 +26,6 @@ StreamId DeserializeStreamIdFromInt64(int64_t); int64_t SerializeTaskIdToInt64(const TaskId&); TaskId DeserializeTaskIdFromInt64(int64_t); -int64_t EncodeMemZoneIdToInt64(const MemZoneId&); -MemZoneId DecodeMemZoneIdFromInt64(int64_t); - } // namespace oneflow #endif // ONEFLOW_CORE_GRAPH_ID_SERIALIZATION_H_ diff --git a/oneflow/core/memory/memory_zone.cpp b/oneflow/core/memory/memory_zone.cpp new file mode 100644 index 00000000000..5ec072fe39d --- /dev/null +++ b/oneflow/core/memory/memory_zone.cpp @@ -0,0 +1,43 @@ +#include "oneflow/core/memory/memory_zone.h" +// #include + +namespace oneflow { + +// TaskId encode (may be extended to 128 bit in future) +// | rank | device_type | device_index | | +// | ----------- 19 ----------- | ---- 5 ---- | ----- 7 ----- | | +// | DeviceId | stream_index | | +// | ------------------------- 31 --------------------------- | ---- 12 ---- | | +// | StreamId | task_index | +// | -------------------------------- 43 ----------------------------------- | --- 21 --- | +// | TaskId | +// | ----------------------------------- 64 bit ----------------------------------------- | + +namespace { + +// constexpr size_t kInt64Bits = sizeof(int64_t) * CHAR_BIT; + +constexpr size_t kMemZoneIdDeviceTypeShift = MemZoneId::kDeviceIndexBits; + +constexpr int64_t kMemZoneIdDeviceTypeInt64Mask = ((int64_t{1} << MemZoneId::kDeviceTypeBits) - 1) + << kMemZoneIdDeviceTypeShift; +constexpr int64_t kMemZoneIdDeviceIndexInt64Mask = (int64_t{1} << MemZoneId::kDeviceIndexBits) - 1; + +const MemZoneId kCPUMemZoneId = MemZoneId{DeviceType::kCPU, MemZoneId::kCPUDeviceIndex}; + +} // namespace + +int64_t EncodeMemZoneIdToInt64(const MemZoneId& mem_zone_id) { + int64_t id = static_cast(mem_zone_id.device_index()); + id |= static_cast(mem_zone_id.device_type()) << kMemZoneIdDeviceTypeShift; + return id; +} + +MemZoneId DecodeMemZoneIdFromInt64(int64_t mem_zone_id) { + int64_t device_type = (mem_zone_id & kMemZoneIdDeviceTypeInt64Mask) >> kMemZoneIdDeviceTypeShift; + int64_t device_index = mem_zone_id & kMemZoneIdDeviceIndexInt64Mask; + return MemZoneId(static_cast(device_type), + static_cast(device_index)); +} + +} // namespace oneflow diff --git a/oneflow/core/memory/memory_zone.h b/oneflow/core/memory/memory_zone.h new file mode 100644 index 00000000000..9791c64d424 --- /dev/null +++ b/oneflow/core/memory/memory_zone.h @@ -0,0 +1,61 @@ +#ifndef ONEFLOW_CORE_MEMORY_MEMORY_ZONE_H_ +#define ONEFLOW_CORE_MEMORY_MEMORY_ZONE_H_ + +#include "oneflow/core/common/util.h" +#include "oneflow/core/common/device_type.pb.h" + +namespace oneflow { + +class MemZoneId { + public: + using device_index_t = uint32_t; + + constexpr static device_index_t kCPUDeviceIndex = 0; + constexpr static size_t kDeviceTypeBits = 5; + constexpr static size_t kDeviceIndexBits = 7; + constexpr static size_t kMaxDeviceTypeVal = (size_t{1} << kDeviceTypeBits) - size_t{1}; + constexpr static device_index_t kMaxDeviceIndex = + (device_index_t{1} << kDeviceIndexBits) - device_index_t{1}; + + MemZoneId(DeviceType device_type, device_index_t device_index) + : device_type_(device_type), device_index_(device_index) { + CHECK_LE(static_cast(device_type), kMaxDeviceTypeVal); + CHECK_LE(device_index, kMaxDeviceIndex); + } + + DeviceType device_type() const { return device_type_; } + device_index_t device_index() const { return device_index_; } + + bool operator==(const MemZoneId& rhs) const { + return device_type_ == rhs.device_type_ && device_index_ == rhs.device_index_; + } + bool operator!=(const MemZoneId& rhs) const { return !(*this == rhs); } + + size_t hash() const { + size_t hash = std::hash{}(static_cast(device_type_)); + HashCombine(&hash, std::hash{}(device_index_)); + return hash; + } + + private: + DeviceType device_type_; + device_index_t device_index_; +}; + +int64_t EncodeMemZoneIdToInt64(const MemZoneId&); +MemZoneId DecodeMemZoneIdFromInt64(int64_t); + +extern const MemZoneId kCPUMemZoneId; + +} // namespace oneflow + +namespace std { + +template<> +struct hash { + size_t operator()(const oneflow::MemZoneId& mem_zone_id) const { return mem_zone_id.hash(); } +}; + +} // namespace std + +#endif // ONEFLOW_CORE_MEMORY_MEMORY_ZONE_H_