Skip to content

Commit

Permalink
move mem zone id source code
Browse files Browse the repository at this point in the history
  • Loading branch information
leaves-zwx committed May 28, 2021
1 parent 7550a12 commit 3859fc2
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 68 deletions.
41 changes: 0 additions & 41 deletions oneflow/core/common/id_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,42 +130,6 @@ class TaskId {
task_index_t task_index_;
};

class MemZoneId {
public:
using device_index_t = uint32_t;

constexpr static device_index_t kCPUDeviceIndex = 0;
constexpr static size_t kDeviceTypeBits = 5;
constexpr static size_t kDeviceIndexBits = 7;
constexpr static size_t kMaxDeviceTypeVal = (size_t{1} << kDeviceTypeBits) - size_t{1};
constexpr static device_index_t kMaxDeviceIndex =
(device_index_t{1} << kDeviceIndexBits) - device_index_t{1};

MemZoneId(DeviceType device_type, device_index_t device_index)
: device_type_(device_type), device_index_(device_index) {
CHECK_LE(static_cast<size_t>(device_type), kMaxDeviceTypeVal);
CHECK_LE(device_index, kMaxDeviceIndex);
}

DeviceType device_type() const { return device_type_; }
device_index_t device_index() const { return device_index_; }

bool operator==(const MemZoneId& rhs) const {
return device_type_ == rhs.device_type_ && device_index_ == rhs.device_index_;
}
bool operator!=(const MemZoneId& rhs) const { return !(*this == rhs); }

size_t hash() const {
size_t hash = std::hash<size_t>{}(static_cast<size_t>(device_type_));
HashCombine(&hash, std::hash<device_index_t>{}(device_index_));
return hash;
}

private:
DeviceType device_type_;
device_index_t device_index_;
};

} // namespace oneflow

namespace std {
Expand All @@ -185,11 +149,6 @@ struct hash<oneflow::TaskId> {
size_t operator()(const oneflow::TaskId& task_id) const { return task_id.hash(); }
};

template<>
struct hash<oneflow::MemZoneId> {
size_t operator()(const oneflow::MemZoneId& mem_zone_id) const { return mem_zone_id.hash(); }
};

} // namespace std

#endif // ONEFLOW_CORE_COMMON_ID_UTIL_H_
24 changes: 0 additions & 24 deletions oneflow/core/graph/id_serialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,28 +119,4 @@ TaskId DeserializeTaskIdFromInt64(int64_t task_id_val) {
return TaskId{stream_id, static_cast<TaskId::task_index_t>(task_index)};
}

namespace mem_zone_id_const {

constexpr size_t kMemZoneIdDeviceTypeShift = MemZoneId::kDeviceIndexBits;
constexpr int64_t kMemZoneIdDeviceTypeInt64Mask = ((int64_t{1} << MemZoneId::kDeviceTypeBits) - 1)
<< kMemZoneIdDeviceTypeShift;
constexpr int64_t kMemZoneIdDeviceIndexInt64Mask = (int64_t{1} << MemZoneId::kDeviceIndexBits) - 1;

} // namespace mem_zone_id_const

int64_t EncodeMemZoneIdToInt64(const MemZoneId& mem_zone_id) {
int64_t id = static_cast<int64_t>(mem_zone_id.device_index());
id |= static_cast<int64_t>(mem_zone_id.device_type())
<< mem_zone_id_const::kMemZoneIdDeviceTypeShift;
return id;
}

MemZoneId DecodeMemZoneIdFromInt64(int64_t mem_zone_id) {
int64_t device_type = (mem_zone_id & mem_zone_id_const::kMemZoneIdDeviceTypeInt64Mask)
>> mem_zone_id_const::kMemZoneIdDeviceTypeShift;
int64_t device_index = mem_zone_id & mem_zone_id_const::kMemZoneIdDeviceIndexInt64Mask;
return MemZoneId(static_cast<DeviceType>(device_type),
static_cast<MemZoneId::device_index_t>(device_index));
}

} // namespace oneflow
3 changes: 0 additions & 3 deletions oneflow/core/graph/id_serialization.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,6 @@ StreamId DeserializeStreamIdFromInt64(int64_t);
int64_t SerializeTaskIdToInt64(const TaskId&);
TaskId DeserializeTaskIdFromInt64(int64_t);

int64_t EncodeMemZoneIdToInt64(const MemZoneId&);
MemZoneId DecodeMemZoneIdFromInt64(int64_t);

} // namespace oneflow

#endif // ONEFLOW_CORE_GRAPH_ID_SERIALIZATION_H_
43 changes: 43 additions & 0 deletions oneflow/core/memory/memory_zone.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#include "oneflow/core/memory/memory_zone.h"
// #include <climits>

namespace oneflow {

// TaskId encode (may be extended to 128 bit in future)
// | rank | device_type | device_index | |
// | ----------- 19 ----------- | ---- 5 ---- | ----- 7 ----- | |
// | DeviceId | stream_index | |
// | ------------------------- 31 --------------------------- | ---- 12 ---- | |
// | StreamId | task_index |
// | -------------------------------- 43 ----------------------------------- | --- 21 --- |
// | TaskId |
// | ----------------------------------- 64 bit ----------------------------------------- |

namespace {

// constexpr size_t kInt64Bits = sizeof(int64_t) * CHAR_BIT;

constexpr size_t kMemZoneIdDeviceTypeShift = MemZoneId::kDeviceIndexBits;

constexpr int64_t kMemZoneIdDeviceTypeInt64Mask = ((int64_t{1} << MemZoneId::kDeviceTypeBits) - 1)
<< kMemZoneIdDeviceTypeShift;
constexpr int64_t kMemZoneIdDeviceIndexInt64Mask = (int64_t{1} << MemZoneId::kDeviceIndexBits) - 1;

const MemZoneId kCPUMemZoneId = MemZoneId{DeviceType::kCPU, MemZoneId::kCPUDeviceIndex};

} // namespace

int64_t EncodeMemZoneIdToInt64(const MemZoneId& mem_zone_id) {
int64_t id = static_cast<int64_t>(mem_zone_id.device_index());
id |= static_cast<int64_t>(mem_zone_id.device_type()) << kMemZoneIdDeviceTypeShift;
return id;
}

MemZoneId DecodeMemZoneIdFromInt64(int64_t mem_zone_id) {
int64_t device_type = (mem_zone_id & kMemZoneIdDeviceTypeInt64Mask) >> kMemZoneIdDeviceTypeShift;
int64_t device_index = mem_zone_id & kMemZoneIdDeviceIndexInt64Mask;
return MemZoneId(static_cast<DeviceType>(device_type),
static_cast<MemZoneId::device_index_t>(device_index));
}

} // namespace oneflow
61 changes: 61 additions & 0 deletions oneflow/core/memory/memory_zone.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
#ifndef ONEFLOW_CORE_MEMORY_MEMORY_ZONE_H_
#define ONEFLOW_CORE_MEMORY_MEMORY_ZONE_H_

#include "oneflow/core/common/util.h"
#include "oneflow/core/common/device_type.pb.h"

namespace oneflow {

class MemZoneId {
public:
using device_index_t = uint32_t;

constexpr static device_index_t kCPUDeviceIndex = 0;
constexpr static size_t kDeviceTypeBits = 5;
constexpr static size_t kDeviceIndexBits = 7;
constexpr static size_t kMaxDeviceTypeVal = (size_t{1} << kDeviceTypeBits) - size_t{1};
constexpr static device_index_t kMaxDeviceIndex =
(device_index_t{1} << kDeviceIndexBits) - device_index_t{1};

MemZoneId(DeviceType device_type, device_index_t device_index)
: device_type_(device_type), device_index_(device_index) {
CHECK_LE(static_cast<size_t>(device_type), kMaxDeviceTypeVal);
CHECK_LE(device_index, kMaxDeviceIndex);
}

DeviceType device_type() const { return device_type_; }
device_index_t device_index() const { return device_index_; }

bool operator==(const MemZoneId& rhs) const {
return device_type_ == rhs.device_type_ && device_index_ == rhs.device_index_;
}
bool operator!=(const MemZoneId& rhs) const { return !(*this == rhs); }

size_t hash() const {
size_t hash = std::hash<size_t>{}(static_cast<size_t>(device_type_));
HashCombine(&hash, std::hash<device_index_t>{}(device_index_));
return hash;
}

private:
DeviceType device_type_;
device_index_t device_index_;
};

int64_t EncodeMemZoneIdToInt64(const MemZoneId&);
MemZoneId DecodeMemZoneIdFromInt64(int64_t);

extern const MemZoneId kCPUMemZoneId;

} // namespace oneflow

namespace std {

template<>
struct hash<oneflow::MemZoneId> {
size_t operator()(const oneflow::MemZoneId& mem_zone_id) const { return mem_zone_id.hash(); }
};

} // namespace std

#endif // ONEFLOW_CORE_MEMORY_MEMORY_ZONE_H_

0 comments on commit 3859fc2

Please sign in to comment.