Skip to content

Commit

Permalink
make bindings for memory planning usable in python
Browse files Browse the repository at this point in the history
  • Loading branch information
Rafael Stahl committed Sep 17, 2021
1 parent 7bf631a commit 28c0954
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 0 deletions.
13 changes: 13 additions & 0 deletions python/tvm/relay/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,9 @@ class StorageInfo(Node):
type of the "virtual devices" the expressions are stored on,
and the sizes of each storage element."""

def __init__(self, sids, dev_types, sizes):
self.__init_handle_by_constructor__(_ffi_api.StorageInfo, sids, dev_types, sizes)

@property
def storage_ids(self):
return _ffi_api.StorageInfoStorageIds(self)
Expand All @@ -560,3 +563,13 @@ def device_types(self):
@property
def storage_sizes(self):
return _ffi_api.StorageInfoStorageSizes(self)


@tvm._ffi.register_object("relay.StaticMemoryPlan")
class StaticMemoryPlan(Node):
"""StaticMemoryPlan
The result of static memory planning."""

def __init__(self, expr_to_storage_info):
self.__init_handle_by_constructor__(_ffi_api.StaticMemoryPlan, expr_to_storage_info)
22 changes: 22 additions & 0 deletions src/relay/backend/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,23 @@ StorageInfo::StorageInfo(std::vector<int64_t> storage_ids, std::vector<DLDeviceT
data_ = std::move(n);
}

TVM_REGISTER_GLOBAL("relay.ir.StorageInfo")
.set_body_typed([](const Array<Integer>& sids, const Array<Integer>& dev_types,
const Array<Integer>& sizes_in_bytes) {
std::vector<int64_t> sids_v, sizes_v;
std::vector<DLDeviceType> dev_types_v;
for (auto s : sids) {
sids_v.push_back(s);
}
for (auto d : dev_types) {
dev_types_v.push_back(static_cast<DLDeviceType>(static_cast<int64_t>(d)));
}
for (auto s : sizes_in_bytes) {
sizes_v.push_back(s);
}
return StorageInfo(sids_v, dev_types_v, sizes_v);
});

TVM_REGISTER_GLOBAL("relay.ir.StorageInfoStorageIds").set_body_typed([](StorageInfo si) {
Array<tvm::Integer> ids;
for (auto id : si->storage_ids) {
Expand Down Expand Up @@ -73,6 +90,11 @@ StaticMemoryPlan::StaticMemoryPlan(Map<Expr, StorageInfo> expr_to_storage_info)
data_ = std::move(n);
}

TVM_REGISTER_GLOBAL("relay.ir.StaticMemoryPlan")
.set_body_typed([](const Map<Expr, StorageInfo>& expr_to_storage_info) {
return StaticMemoryPlan(expr_to_storage_info);
});

int64_t CalculateRelayExprSizeBytes(const Type& expr_type) {
if (expr_type->IsInstance<TupleTypeNode>()) {
auto tuple_type = Downcast<TupleType>(expr_type);
Expand Down

0 comments on commit 28c0954

Please sign in to comment.