[Unity] Add an API to create multiple kv caches with single allocation#15064
tqchen merged 2 commits into apache:unity
Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.
Generated by tvm-bot
src/runtime/relax_vm/lm_support.cc
Outdated
    Array<AttentionKVCache> result;
    for (int i = 0; i < num_caches; ++i) {
      // Use DLManagedTensor to prevent underlying memory from being freed
      DLManagedTensor* data_view = block_view.ToDLPack();
We can likely reuse the memory allocator (storage interface) without having to go through DLPack.
Thanks! I updated the code to use the storage interface, and it looks cleaner. However, it may now print a warning message if the requested allocator type does not match the allocator that was created at VM initialization.
yzh119
left a comment
LGTM, and thanks for doing this. I just have a few minor comments.
        int init_fill_count, int num_caches) {
      DLDataType dtype = init_data->dtype;

      int64_t cache_size = (dtype.bits * dtype.lanes + 7) / 8;
So currently, if the dtype is smaller than one byte, we would pad it to one byte. Is that correct?
FYI: FlexGen uses a 4-bit KV cache; we can support it later.
I think it is fine for now, since sub-byte types are usually packed manually (the dtype is i32).
Thanks for the clarification, makes sense to me.
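For readers following the thread above, here is a minimal sketch of the rounding behaviour being discussed. `DataType` and `BytesPerElement` are hypothetical names used only for illustration; `DataType` mirrors the three fields of DLPack's `DLDataType`, and the arithmetic is the expression from the diff.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for DLPack's DLDataType (same three fields).
struct DataType {
  uint8_t code;    // type code (int, uint, float, ...), unused here
  uint8_t bits;    // bits per lane
  uint16_t lanes;  // number of lanes (1 for scalar types)
};

// Bytes needed for one element, rounded up to a whole byte.
// This is the (bits * lanes + 7) / 8 expression from the diff.
inline int64_t BytesPerElement(DataType dtype) {
  assert(dtype.bits > 0 && dtype.lanes > 0);
  return (static_cast<int64_t>(dtype.bits) * dtype.lanes + 7) / 8;
}
```

Under this formula a 4-bit scalar dtype is padded to 1 byte, while 4-bit values packed manually into an i32 are simply counted as 4 bytes per element, which matches the reply above.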
This would be useful when creating multiple kv caches with the same shape. On A10G, compared to creating 64 kv caches separately in LLaMA from mlc-llm, doing a single allocation can save about 35 ms.
@tqchen @junrushao
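To illustrate the idea of backing several same-shaped caches with one allocation, here is a small standalone sketch. It does not use TVM's storage interface or any TVM API; `CacheView` and `CreateMultipleCaches` are hypothetical names, the block is an ordinary host-side `std::vector`, and alignment is ignored. It only shows the layout: one block, `num_caches` views at fixed byte offsets, each view keeping the block alive.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <vector>

// Hypothetical view over a slice of a shared block.
// Holding the shared_ptr keeps the single allocation alive
// for as long as any view exists.
struct CacheView {
  std::shared_ptr<std::vector<uint8_t>> block;
  size_t offset;  // byte offset of this cache inside the block
  size_t size;    // size of this cache in bytes
  uint8_t* data() const { return block->data() + offset; }
};

// Allocate one block of bytes_per_cache * num_caches bytes and
// return num_caches non-overlapping views into it.
inline std::vector<CacheView> CreateMultipleCaches(size_t bytes_per_cache,
                                                   int num_caches) {
  assert(num_caches > 0);
  auto block =
      std::make_shared<std::vector<uint8_t>>(bytes_per_cache * num_caches);
  std::vector<CacheView> views;
  views.reserve(num_caches);
  for (int i = 0; i < num_caches; ++i) {
    views.push_back(CacheView{block, i * bytes_per_cache, bytes_per_cache});
  }
  return views;
}
```

Calling `CreateMultipleCaches(bytes, 64)` performs one allocation instead of 64; the saving reported above comes from avoiding the repeated allocator calls, not from any change in total memory.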