diff --git a/include/tvm/runtime/relax_vm/memory_manager.h b/include/tvm/runtime/relax_vm/memory_manager.h index 9234e9151ce0..55952de3f8d1 100644 --- a/include/tvm/runtime/relax_vm/memory_manager.h +++ b/include/tvm/runtime/relax_vm/memory_manager.h @@ -97,6 +97,9 @@ class MemoryManager { */ static Allocator* GetAllocator(Device dev); + /*! \brief Clear the allocators. */ + static void Clear(); + private: MemoryManager() {} diff --git a/src/runtime/relax_vm/memory_manager.cc b/src/runtime/relax_vm/memory_manager.cc index 7eedad2e56c0..2391bdc284fa 100644 --- a/src/runtime/relax_vm/memory_manager.cc +++ b/src/runtime/relax_vm/memory_manager.cc @@ -21,6 +21,7 @@ * \file tvm/runtime/relax_vm/memory_manager.cc * \brief Allocate and manage memory for the Relay VM. */ +#include #include #include @@ -169,6 +170,12 @@ Allocator* MemoryManager::GetAllocator(Device dev) { return it->second.get(); } +void MemoryManager::Clear() { + MemoryManager* m = MemoryManager::Global(); + std::lock_guard lock(m->mutex_); + m->allocators_.clear(); +} + runtime::NDArray Allocator::Empty(ShapeTuple shape, DLDataType dtype, DLDevice dev) { VerifyDataType(dtype); runtime::NDArray::Container* container = @@ -183,6 +190,8 @@ runtime::NDArray Allocator::Empty(ShapeTuple shape, DLDataType dtype, DLDevice d return runtime::NDArray(runtime::GetObjectPtr(container)); } +TVM_REGISTER_GLOBAL("vm.builtin.memory_manager.clear").set_body_typed(MemoryManager::Clear); + } // namespace relax_vm } // namespace runtime } // namespace tvm