From 5be91a05501cfe11ef7092a82fb967687f151a17 Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Thu, 8 Jun 2023 13:16:51 -0400 Subject: [PATCH 1/2] Add a batch create API for kv cache --- src/runtime/relax_vm/lm_support.cc | 47 ++++++++++++++++++++++ tests/python/relax/test_runtime_builtin.py | 28 ++++++++++++- 2 files changed, 73 insertions(+), 2 deletions(-) diff --git a/src/runtime/relax_vm/lm_support.cc b/src/runtime/relax_vm/lm_support.cc index cfc596d47653..765b34c30113 100644 --- a/src/runtime/relax_vm/lm_support.cc +++ b/src/runtime/relax_vm/lm_support.cc @@ -167,12 +167,59 @@ class AttentionKVCache : public ObjectRef { TVM_REGISTER_OBJECT_TYPE(AttentionKVCacheObj); +/*! + * \brief Create multiple KV caches with the same shape, backed by a single memory allocation. + * \param init_data The initial data for each cache. Its dtype and device are always used, but it is + * only appended when init_fill_count is greater than 0. + * \param reserve_shape The shape of each cache. + * \param init_fill_count The initial fill count of each cache; values less than + * or equal to 0 leave the caches empty. + * \param num_caches Number of caches to create. 
+ */ +Array CreateMultipleKVCaches(NDArray init_data, ShapeTuple reserve_shape, + int init_fill_count, int num_caches) { + DLDataType dtype = init_data->dtype; + + int64_t cache_size = (dtype.bits * dtype.lanes + 7) / 8; + for (const auto dim : reserve_shape) { + cache_size *= dim; + } + + // Add padding to make each cache align to kAllocAlignment + using tvm::runtime::kAllocAlignment; + int64_t padding = (kAllocAlignment - cache_size % kAllocAlignment) % kAllocAlignment; + int64_t cache_offset = cache_size + padding; + + auto block = NDArray::Empty(ShapeTuple({cache_offset * num_caches}), dtype, init_data->device); + auto block_view = block.CreateView(reserve_shape, dtype); + + Array result; + for (int i = 0; i < num_caches; ++i) { + // Use DLManagedTensor to prevent underlying memory from being freed + DLManagedTensor* data_view = block_view.ToDLPack(); + data_view->dl_tensor.data = (void*)((char*)(data_view->dl_tensor.data) + i * cache_offset); + + auto c = make_object(); + c->data = NDArray::FromDLPack(data_view); + c->fill_count = 0; + if (init_fill_count > 0) { + c->Append(init_data); + c->fill_count = init_fill_count; + } + result.push_back(AttentionKVCache(c)); + } + return result; +} + //------------------------------------------------- // Register runtime functions //------------------------------------------------- TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_create") .set_body_typed(AttentionKVCache::Create); +TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_create_multiple") + .set_body_typed(CreateMultipleKVCaches); + AttentionKVCache AttentionKVCacheUpdate(AttentionKVCache cache, NDArray value) { cache->Update(value); return cache; diff --git a/tests/python/relax/test_runtime_builtin.py b/tests/python/relax/test_runtime_builtin.py index d25841a71f6b..682e9d712d58 100644 --- a/tests/python/relax/test_runtime_builtin.py +++ b/tests/python/relax/test_runtime_builtin.py @@ -158,9 +158,9 @@ def test_attention_kv_cache(): fview = 
tvm.get_global_func("vm.builtin.attention_kv_cache_view") cache = fcreate(tvm.nd.empty((1, 2), dtype="int32"), tvm.runtime.ShapeTuple([2, 2]), 0) - num_steps = 0 + num_steps = 2 for i in range(num_steps): - cache = fappend(cache, tvm.nd.array(i * np.ones((1, 2).astype("int32")))) + cache = fappend(cache, tvm.nd.array(i * np.ones((1, 2)).astype("int32"))) res = fview(cache, tvm.runtime.ShapeTuple((num_steps, 2))).numpy() for i in range(num_steps): @@ -168,6 +168,30 @@ def test_attention_kv_cache(): assert res[i][1] == i +def test_attention_kv_cache_create_multiple(): + fcreate = tvm.get_global_func("vm.builtin.attention_kv_cache_create_multiple") + fappend = tvm.get_global_func("vm.builtin.attention_kv_cache_append") + fview = tvm.get_global_func("vm.builtin.attention_kv_cache_view") + + num_caches = 4 + cache_group = fcreate( + tvm.nd.empty((1, 2), dtype="int32"), tvm.runtime.ShapeTuple([7, 2]), 0, num_caches + ) + + num_steps = 7 + for i in range(num_steps): + for cache_index in range(num_caches): + fappend( + cache_group[cache_index], + tvm.nd.array(i * cache_index * np.ones((1, 2)).astype("int32")), + ) + res = fview(cache_group[cache_index], tvm.runtime.ShapeTuple((i + 1, 2))).numpy() + # Also verify that the old values aren't corrupted + for j in range(i): + assert res[j][0] == j * cache_index + assert res[j][1] == j * cache_index + + def test_ndarray_cache(): fload = tvm.get_global_func("vm.builtin.ndarray_cache.load") fget_params = tvm.get_global_func("vm.builtin.param_array_from_cache") From ec4fc48572c90e3bbf7b490c22b7c8742999a66e Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Fri, 9 Jun 2023 14:14:52 -0400 Subject: [PATCH 2/2] Use storage api --- src/runtime/relax_vm/lm_support.cc | 12 +++++------- src/runtime/relax_vm/memory_manager.cc | 7 +++++++ 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/src/runtime/relax_vm/lm_support.cc b/src/runtime/relax_vm/lm_support.cc index 765b34c30113..9b14161e67db 100644 --- 
a/src/runtime/relax_vm/lm_support.cc +++ b/src/runtime/relax_vm/lm_support.cc @@ -40,6 +40,7 @@ #include #include #include +#include #include #include @@ -190,17 +191,14 @@ Array CreateMultipleKVCaches(NDArray init_data, ShapeTuple res int64_t padding = (kAllocAlignment - cache_size % kAllocAlignment) % kAllocAlignment; int64_t cache_offset = cache_size + padding; - auto block = NDArray::Empty(ShapeTuple({cache_offset * num_caches}), dtype, init_data->device); - auto block_view = block.CreateView(reserve_shape, dtype); + Storage storage = + Storage(MemoryManager::GetOrCreateAllocator(init_data->device, AllocatorType::kNaive) + ->Alloc(cache_offset * num_caches, kAllocAlignment, dtype)); Array result; for (int i = 0; i < num_caches; ++i) { - // Use DLManagedTensor to prevent underlying memory from being freed - DLManagedTensor* data_view = block_view.ToDLPack(); - data_view->dl_tensor.data = (void*)((char*)(data_view->dl_tensor.data) + i * cache_offset); - auto c = make_object(); - c->data = NDArray::FromDLPack(data_view); + c->data = storage->AllocNDArray(i * cache_offset, reserve_shape, dtype); c->fill_count = 0; if (init_fill_count > 0) { c->Append(init_data); diff --git a/src/runtime/relax_vm/memory_manager.cc b/src/runtime/relax_vm/memory_manager.cc index 339045f515cf..7eedad2e56c0 100644 --- a/src/runtime/relax_vm/memory_manager.cc +++ b/src/runtime/relax_vm/memory_manager.cc @@ -28,6 +28,7 @@ #include "naive_allocator.h" #include "pooled_allocator.h" +#include "tvm/runtime/memory.h" namespace tvm { namespace runtime { @@ -58,6 +59,12 @@ void StorageObj::Deleter(Object* obj) { delete ptr; } +Storage::Storage(Buffer buffer) { + auto n = make_object(); + n->buffer = std::move(buffer); + data_ = std::move(n); +} + inline void VerifyDataType(DLDataType dtype) { ICHECK_GE(dtype.lanes, 1); if (dtype.code == kDLFloat) {