-
Notifications
You must be signed in to change notification settings - Fork 661
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
7550a12
commit 3859fc2
Showing
5 changed files
with
104 additions
and
68 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
#include "oneflow/core/memory/memory_zone.h" | ||
// #include <climits> | ||
|
||
namespace oneflow { | ||
|
||
// TaskId encode (may be extended to 128 bit in future) | ||
// | rank | device_type | device_index | | | ||
// | ----------- 19 ----------- | ---- 5 ---- | ----- 7 ----- | | | ||
// | DeviceId | stream_index | | | ||
// | ------------------------- 31 --------------------------- | ---- 12 ---- | | | ||
// | StreamId | task_index | | ||
// | -------------------------------- 43 ----------------------------------- | --- 21 --- | | ||
// | TaskId | | ||
// | ----------------------------------- 64 bit ----------------------------------------- | | ||
|
||
namespace { | ||
|
||
// constexpr size_t kInt64Bits = sizeof(int64_t) * CHAR_BIT; | ||
|
||
constexpr size_t kMemZoneIdDeviceTypeShift = MemZoneId::kDeviceIndexBits; | ||
|
||
constexpr int64_t kMemZoneIdDeviceTypeInt64Mask = ((int64_t{1} << MemZoneId::kDeviceTypeBits) - 1) | ||
<< kMemZoneIdDeviceTypeShift; | ||
constexpr int64_t kMemZoneIdDeviceIndexInt64Mask = (int64_t{1} << MemZoneId::kDeviceIndexBits) - 1; | ||
|
||
const MemZoneId kCPUMemZoneId = MemZoneId{DeviceType::kCPU, MemZoneId::kCPUDeviceIndex}; | ||
|
||
} // namespace | ||
|
||
int64_t EncodeMemZoneIdToInt64(const MemZoneId& mem_zone_id) { | ||
int64_t id = static_cast<int64_t>(mem_zone_id.device_index()); | ||
id |= static_cast<int64_t>(mem_zone_id.device_type()) << kMemZoneIdDeviceTypeShift; | ||
return id; | ||
} | ||
|
||
MemZoneId DecodeMemZoneIdFromInt64(int64_t mem_zone_id) { | ||
int64_t device_type = (mem_zone_id & kMemZoneIdDeviceTypeInt64Mask) >> kMemZoneIdDeviceTypeShift; | ||
int64_t device_index = mem_zone_id & kMemZoneIdDeviceIndexInt64Mask; | ||
return MemZoneId(static_cast<DeviceType>(device_type), | ||
static_cast<MemZoneId::device_index_t>(device_index)); | ||
} | ||
|
||
} // namespace oneflow |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
#ifndef ONEFLOW_CORE_MEMORY_MEMORY_ZONE_H_ | ||
#define ONEFLOW_CORE_MEMORY_MEMORY_ZONE_H_ | ||
|
||
#include "oneflow/core/common/util.h" | ||
#include "oneflow/core/common/device_type.pb.h" | ||
|
||
namespace oneflow { | ||
|
||
class MemZoneId { | ||
public: | ||
using device_index_t = uint32_t; | ||
|
||
constexpr static device_index_t kCPUDeviceIndex = 0; | ||
constexpr static size_t kDeviceTypeBits = 5; | ||
constexpr static size_t kDeviceIndexBits = 7; | ||
constexpr static size_t kMaxDeviceTypeVal = (size_t{1} << kDeviceTypeBits) - size_t{1}; | ||
constexpr static device_index_t kMaxDeviceIndex = | ||
(device_index_t{1} << kDeviceIndexBits) - device_index_t{1}; | ||
|
||
MemZoneId(DeviceType device_type, device_index_t device_index) | ||
: device_type_(device_type), device_index_(device_index) { | ||
CHECK_LE(static_cast<size_t>(device_type), kMaxDeviceTypeVal); | ||
CHECK_LE(device_index, kMaxDeviceIndex); | ||
} | ||
|
||
DeviceType device_type() const { return device_type_; } | ||
device_index_t device_index() const { return device_index_; } | ||
|
||
bool operator==(const MemZoneId& rhs) const { | ||
return device_type_ == rhs.device_type_ && device_index_ == rhs.device_index_; | ||
} | ||
bool operator!=(const MemZoneId& rhs) const { return !(*this == rhs); } | ||
|
||
size_t hash() const { | ||
size_t hash = std::hash<size_t>{}(static_cast<size_t>(device_type_)); | ||
HashCombine(&hash, std::hash<device_index_t>{}(device_index_)); | ||
return hash; | ||
} | ||
|
||
private: | ||
DeviceType device_type_; | ||
device_index_t device_index_; | ||
}; | ||
|
||
int64_t EncodeMemZoneIdToInt64(const MemZoneId&); | ||
MemZoneId DecodeMemZoneIdFromInt64(int64_t); | ||
|
||
extern const MemZoneId kCPUMemZoneId; | ||
|
||
} // namespace oneflow | ||
|
||
namespace std { | ||
|
||
template<> | ||
struct hash<oneflow::MemZoneId> { | ||
size_t operator()(const oneflow::MemZoneId& mem_zone_id) const { return mem_zone_id.hash(); } | ||
}; | ||
|
||
} // namespace std | ||
|
||
#endif // ONEFLOW_CORE_MEMORY_MEMORY_ZONE_H_ |