From 15770f316a887d50addd771e38a1e4c25d4beaa1 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Thu, 18 May 2023 18:10:36 -0700 Subject: [PATCH] [Unity] Support clear global memory allocators This PR supports clearing up all the allocated memory in among relax VMs. Prior to this PR, all the allocated memory are managed in the pool of memory manager. The allocated memory in the pool is never freed and the pool size always goes up monotonically. While good to save time of memory allocation, in some cases (e.g., on mobile phones which may have running memory limit) we need to clear the pool and free all the memory in order to prevent the pool from endlessly growing up and some of allocated memory not being effectively utilized. Therefore, this PR introduces a PackedFunc that helps clear the pool. --- include/tvm/runtime/relax_vm/memory_manager.h | 3 +++ src/runtime/relax_vm/memory_manager.cc | 9 +++++++++ 2 files changed, 12 insertions(+) diff --git a/include/tvm/runtime/relax_vm/memory_manager.h b/include/tvm/runtime/relax_vm/memory_manager.h index 9234e9151ce0..55952de3f8d1 100644 --- a/include/tvm/runtime/relax_vm/memory_manager.h +++ b/include/tvm/runtime/relax_vm/memory_manager.h @@ -97,6 +97,9 @@ class MemoryManager { */ static Allocator* GetAllocator(Device dev); + /*! \brief Clear the allocators. */ + static void Clear(); + private: MemoryManager() {} diff --git a/src/runtime/relax_vm/memory_manager.cc b/src/runtime/relax_vm/memory_manager.cc index 7eedad2e56c0..2391bdc284fa 100644 --- a/src/runtime/relax_vm/memory_manager.cc +++ b/src/runtime/relax_vm/memory_manager.cc @@ -21,6 +21,7 @@ * \file tvm/runtime/relax_vm/memory_manager.cc * \brief Allocate and manage memory for the Relay VM. */ +#include #include #include @@ -169,6 +170,12 @@ Allocator* MemoryManager::GetAllocator(Device dev) { return it->second.get(); } +void MemoryManager::Clear() { + MemoryManager* m = MemoryManager::Global(); + std::lock_guard lock(m->mutex_); + m->allocators_.clear(); +} + runtime::NDArray Allocator::Empty(ShapeTuple shape, DLDataType dtype, DLDevice dev) { VerifyDataType(dtype); runtime::NDArray::Container* container = @@ -183,6 +190,8 @@ runtime::NDArray Allocator::Empty(ShapeTuple shape, DLDataType dtype, DLDevice d return runtime::NDArray(runtime::GetObjectPtr(container)); } +TVM_REGISTER_GLOBAL("vm.builtin.memory_manager.clear").set_body_typed(MemoryManager::Clear); + } // namespace relax_vm } // namespace runtime } // namespace tvm