Skip to content

Commit

Permalink
MemZoneId
Browse files Browse the repository at this point in the history
  • Loading branch information
leaves-zwx committed May 24, 2021
1 parent 49d53fe commit 7550a12
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 0 deletions.
41 changes: 41 additions & 0 deletions oneflow/core/common/id_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,42 @@ class TaskId {
task_index_t task_index_;
};

class MemZoneId {
public:
using device_index_t = uint32_t;

constexpr static device_index_t kCPUDeviceIndex = 0;
constexpr static size_t kDeviceTypeBits = 5;
constexpr static size_t kDeviceIndexBits = 7;
constexpr static size_t kMaxDeviceTypeVal = (size_t{1} << kDeviceTypeBits) - size_t{1};
constexpr static device_index_t kMaxDeviceIndex =
(device_index_t{1} << kDeviceIndexBits) - device_index_t{1};

MemZoneId(DeviceType device_type, device_index_t device_index)
: device_type_(device_type), device_index_(device_index) {
CHECK_LE(static_cast<size_t>(device_type), kMaxDeviceTypeVal);
CHECK_LE(device_index, kMaxDeviceIndex);
}

DeviceType device_type() const { return device_type_; }
device_index_t device_index() const { return device_index_; }

bool operator==(const MemZoneId& rhs) const {
return device_type_ == rhs.device_type_ && device_index_ == rhs.device_index_;
}
bool operator!=(const MemZoneId& rhs) const { return !(*this == rhs); }

size_t hash() const {
size_t hash = std::hash<size_t>{}(static_cast<size_t>(device_type_));
HashCombine(&hash, std::hash<device_index_t>{}(device_index_));
return hash;
}

private:
DeviceType device_type_;
device_index_t device_index_;
};

} // namespace oneflow

namespace std {
Expand All @@ -149,6 +185,11 @@ struct hash<oneflow::TaskId> {
size_t operator()(const oneflow::TaskId& task_id) const { return task_id.hash(); }
};

template<>
struct hash<oneflow::MemZoneId> {
size_t operator()(const oneflow::MemZoneId& mem_zone_id) const { return mem_zone_id.hash(); }
};

} // namespace std

#endif // ONEFLOW_CORE_COMMON_ID_UTIL_H_
24 changes: 24 additions & 0 deletions oneflow/core/graph/id_serialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,4 +119,28 @@ TaskId DeserializeTaskIdFromInt64(int64_t task_id_val) {
return TaskId{stream_id, static_cast<TaskId::task_index_t>(task_index)};
}

namespace mem_zone_id_const {

constexpr size_t kMemZoneIdDeviceTypeShift = MemZoneId::kDeviceIndexBits;
constexpr int64_t kMemZoneIdDeviceTypeInt64Mask = ((int64_t{1} << MemZoneId::kDeviceTypeBits) - 1)
<< kMemZoneIdDeviceTypeShift;
constexpr int64_t kMemZoneIdDeviceIndexInt64Mask = (int64_t{1} << MemZoneId::kDeviceIndexBits) - 1;

} // namespace mem_zone_id_const

int64_t EncodeMemZoneIdToInt64(const MemZoneId& mem_zone_id) {
int64_t id = static_cast<int64_t>(mem_zone_id.device_index());
id |= static_cast<int64_t>(mem_zone_id.device_type())
<< mem_zone_id_const::kMemZoneIdDeviceTypeShift;
return id;
}

MemZoneId DecodeMemZoneIdFromInt64(int64_t mem_zone_id) {
int64_t device_type = (mem_zone_id & mem_zone_id_const::kMemZoneIdDeviceTypeInt64Mask)
>> mem_zone_id_const::kMemZoneIdDeviceTypeShift;
int64_t device_index = mem_zone_id & mem_zone_id_const::kMemZoneIdDeviceIndexInt64Mask;
return MemZoneId(static_cast<DeviceType>(device_type),
static_cast<MemZoneId::device_index_t>(device_index));
}

} // namespace oneflow
4 changes: 4 additions & 0 deletions oneflow/core/graph/id_serialization.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,13 @@ namespace oneflow {

int64_t SerializeStreamIdToInt64(const StreamId&);
StreamId DeserializeStreamIdFromInt64(int64_t);

int64_t SerializeTaskIdToInt64(const TaskId&);
TaskId DeserializeTaskIdFromInt64(int64_t);

int64_t EncodeMemZoneIdToInt64(const MemZoneId&);
MemZoneId DecodeMemZoneIdFromInt64(int64_t);

} // namespace oneflow

#endif // ONEFLOW_CORE_GRAPH_ID_SERIALIZATION_H_

0 comments on commit 7550a12

Please sign in to comment.