[KVCache] Unlimited depth blocks (#17100)

cyx-6 committed Jun 18, 2024
1 parent 5bfca2e commit 675a023
Showing 3 changed files with 185 additions and 86 deletions.
175 changes: 111 additions & 64 deletions src/runtime/relax_vm/paged_kv_cache.cc
@@ -178,10 +178,6 @@ struct Sequence {
}
block_ptr = block.parent_idx;
}
CHECK_LE(depth, kPagedKVCacheMaxBlockDepth)
<< "Paged KV cache supports one sequence to reuse " << kPagedKVCacheMaxBlockDepth
<< " prefixes (the fork depth) at most. However, the given sequence has fork depth "
<< depth;
}

std::vector<int32_t> GetBlockTrace(const std::vector<Block>& global_block_pool) const {
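The hunk above deletes the fork-depth cap entirely. For reference, the depth that the removed CHECK_LE bounded is the number of blocks on the chain from a sequence's last block back to the root; a minimal self-contained sketch of that walk, using simplified stand-in types rather than the actual TVM structs:

```cpp
#include <cstdint>
#include <vector>

// Simplified stand-in for a pool entry; only the parent link matters here.
struct BlockNode {
  int32_t parent_idx = -1;  // index of the parent block, or -1 at the root
};

// Fork depth of a sequence: count the blocks from its last block to the root.
// This is the quantity the removed check bounded by kPagedKVCacheMaxBlockDepth;
// after this commit any depth is accepted and the excess is handled as
// trailing blocks at attention time.
int ForkDepth(int32_t last_block_idx, const std::vector<BlockNode>& pool) {
  int depth = 0;
  for (int32_t ptr = last_block_idx; ptr != -1; ptr = pool[ptr].parent_idx) {
    ++depth;
  }
  return depth;
}
```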
@@ -1199,44 +1195,38 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
<< "The parent sequence's token tree computed in the last round of forward has not been "
"committed with accepted nodes.";

if (fork_pos == -1) {
fork_pos = parent_it->second.seq_length;
}

if (fork_pos == parent_it->second.seq_length && fork_pos % page_size_ == 0 &&
global_block_pool_[parent_it->second.last_block_idx].seq_length > 0) {
// To enable the parent sequence to continue decode after the fork,
// we add a new empty block at the end of the parent sequence.
// So the new decoded KV data will go into the new block.
int32_t new_block_idx = GetFreeBlock();
global_block_pool_[new_block_idx].start_pos = parent_it->second.seq_length;
global_block_pool_[new_block_idx].parent_idx = parent_it->second.last_block_idx;
global_block_pool_[new_block_idx].external_ref_cnt = 1;
parent_it->second.last_block_idx = new_block_idx;
}

int32_t child_block_idx = GetFreeBlock();
if (fork_pos == -1 || fork_pos == parent_it->second.seq_length) {
// Fork at the end by appending a new block directly
int32_t parent_block_idx = parent_it->second.last_block_idx;
if (!global_block_pool_[parent_block_idx].seq_length) {
// If parent ends with empty block, fork from parent's parent block
parent_block_idx = global_block_pool_[parent_block_idx].parent_idx;
}
++global_block_pool_[parent_block_idx].external_ref_cnt;
// Update child block start position and parent index
global_block_pool_[child_block_idx].start_pos = parent_it->second.seq_length;
global_block_pool_[child_block_idx].parent_idx = parent_block_idx;
if (parent_block_idx == parent_it->second.last_block_idx &&
global_block_pool_[parent_block_idx].seq_length) {
// To enable the parent sequence to continue decode after the fork,
// we add a new empty block at the end of the parent sequence.
// So the new decoded KV data will go into the new block.
int32_t new_parent_block_idx = GetFreeBlock();
global_block_pool_[new_parent_block_idx].start_pos = parent_it->second.seq_length;
global_block_pool_[new_parent_block_idx].parent_idx = parent_block_idx;
global_block_pool_[new_parent_block_idx].external_ref_cnt = 1;
parent_it->second.last_block_idx = new_parent_block_idx;
}
} else {
// Locate the block to fork from and calculate in-block offset
std::vector<int32_t> trace = parent_it->second.GetBlockTrace(global_block_pool_);
int64_t in_block_offset = fork_pos;
int32_t forked_block_idx = -1;
for (int32_t block_idx : trace) {
if (in_block_offset < global_block_pool_[block_idx].seq_length) {
forked_block_idx = block_idx;
break;
std::vector<int32_t> trace = parent_it->second.GetBlockTrace(global_block_pool_);
int64_t in_block_offset = fork_pos;
for (int32_t forked_block_idx : trace) {
if (forked_block_idx != trace.back()) {
CHECK_GT(global_block_pool_[forked_block_idx].seq_length, 0);
CHECK_EQ(global_block_pool_[forked_block_idx].seq_length % page_size_, 0);
if (global_block_pool_[forked_block_idx].seq_length <= in_block_offset) {
in_block_offset -= global_block_pool_[forked_block_idx].seq_length;
continue;
}
in_block_offset -= global_block_pool_[block_idx].seq_length;
}
int32_t in_page_offset = in_block_offset % page_size_;
int32_t moved_offset = in_block_offset - in_page_offset;
if (moved_offset == 0) {
int32_t moved_pages = moved_offset / page_size_;
if (moved_pages == 0) {
// Forked at the first page in block
int32_t parent_block_idx = global_block_pool_[forked_block_idx].parent_idx;
if (parent_block_idx != -1) {
@@ -1256,8 +1246,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {

// Move common leading pages to new parent block
auto first_page = global_block_pool_[forked_block_idx].page_ids.begin();
auto last_page =
global_block_pool_[forked_block_idx].page_ids.begin() + moved_offset / page_size_;
auto last_page = global_block_pool_[forked_block_idx].page_ids.begin() + moved_pages;
global_block_pool_[parent_block_idx].page_ids = {first_page, last_page};
global_block_pool_[forked_block_idx].page_ids.erase(first_page, last_page);

@@ -1280,6 +1269,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
global_block_pool_[child_block_idx].page_ids.push_back(tgt_page_id);
CopySinglePage(src_page_id, tgt_page_id, in_page_offset);
}
break;
}
// Create the child sequence with the child block.
seq_map_.insert({child_seq_id, Sequence(&global_block_pool_, child_block_idx)});
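In the else branch above, the fork position is resolved by walking the block trace: whole blocks before the fork are skipped, and the remaining in-block offset is split into whole pages (moved to a new parent block) plus an in-page remainder (copied into the child's first page). A hedged, self-contained sketch of that arithmetic, with a simplified block type and a hypothetical LocateFork helper rather than the actual TVM API:

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Simplified stand-in for a pool entry; only the fields used here.
struct SeqBlock {
  int64_t seq_length = 0;         // number of tokens stored in this block
  std::vector<int32_t> page_ids;  // pages backing this block
};

struct ForkPoint {
  std::size_t trace_index;  // which block of the trace holds fork_pos
  int64_t moved_pages;      // whole pages before the fork inside that block
  int64_t in_page_offset;   // tokens into the page that holds fork_pos
};

// Hypothetical helper mirroring the walk above: skip blocks that end before
// fork_pos, then split the in-block offset into pages plus a remainder.
ForkPoint LocateFork(const std::vector<SeqBlock>& trace, int64_t fork_pos,
                     int64_t page_size) {
  int64_t in_block_offset = fork_pos;
  for (std::size_t i = 0; i < trace.size(); ++i) {
    bool is_last = (i + 1 == trace.size());
    if (!is_last && trace[i].seq_length <= in_block_offset) {
      in_block_offset -= trace[i].seq_length;  // fork lies past this block
      continue;
    }
    int64_t in_page_offset = in_block_offset % page_size;
    return {i, (in_block_offset - in_page_offset) / page_size, in_page_offset};
  }
  return {trace.size(), 0, 0};  // empty trace; callers never hit this
}
```

When moved_pages is zero the fork lands on the block's first page and the existing parent link can be reused, which corresponds to the moved_pages == 0 branch in the diff.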
@@ -1496,19 +1486,29 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
is_chain_ = true;
}

std::vector<std::vector<int32_t>> block_ids_on_depths = GetBlockIdsOnDepth(sequences);
num_depths_ = block_ids_on_depths.size();
auto [block_ids_on_depths, trailing_blocks] = GetBlockIdsOnDepth(sequences);
num_depths_ =
std::min(static_cast<int>(block_ids_on_depths.size()), kPagedKVCacheMaxBlockDepth);
ICHECK_LE(num_depths_, kPagedKVCacheMaxBlockDepth);

std::vector<std::vector<std::pair<int32_t, int32_t>>> chunked_block_ids_arr;
chunked_block_ids_arr.reserve(num_depths_);
use_decode_kernel_.clear();
for (int d = 0; d < num_depths_; ++d) {
auto [chunked_block_ids, use_decode_kernel] = GetChunkedBlockIds(block_ids_on_depths[d]);
// We force the blocks at maximum depth not to coalesce, so that they can be
// concatenated with the trailing blocks that exceed the maximum depth.
auto [chunked_block_ids, use_decode_kernel] = GetChunkedBlockIds(
block_ids_on_depths[d], /*enable_coalesce=*/d != kPagedKVCacheMaxBlockDepth - 1);
chunked_block_ids_arr.push_back(chunked_block_ids);
use_decode_kernel_.push_back(use_decode_kernel);
}

if (num_depths_ == kPagedKVCacheMaxBlockDepth) {
// Since we force the blocks at maximum depth not to coalesce, the number of output
// blocks at the maximum depth must equal the current batch size.
CHECK_EQ(chunked_block_ids_arr[num_depths_ - 1].size(), cur_batch_size_);
}

append_before_attn_ = !support_sliding_window_ && num_depths_ == 1 && use_decode_kernel_[0];
if (append_before_attn_) {
// Right now we use different kernels when depth is 1 or not 1.
@@ -1536,7 +1536,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
k_rope_pos_offset_h.clear();
qo_indptr_h.push_back(0);
page_indptr_h.push_back(0);
for (const auto& [block_id, chunk_append_length] : chunked_block_ids_arr[d]) {
for (int i = 0; i < static_cast<int>(chunked_block_ids_arr[d].size()); ++i) {
const auto& [block_id, chunk_append_length] = chunked_block_ids_arr[d][i];
qo_indptr_h.push_back(qo_indptr_h.back() + chunk_append_length);
if (block_id == -1) {
page_indptr_h.push_back(page_indptr_h.back());
@@ -1545,19 +1546,53 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
sink_size_h.push_back(0);
k_rope_pos_offset_h.push_back(0);
} else {
const Block& block = global_block_pool_[block_id];
page_indptr_h.push_back(page_indptr_h.back() + block.page_ids.size());
for (int32_t page_id : block.page_ids) {
page_indices_h.push_back(page_id);
if (d < kPagedKVCacheMaxBlockDepth - 1) {
// Blocks not at maximum depth
const Block& block = global_block_pool_[block_id];
page_indptr_h.push_back(page_indptr_h.back() + block.page_ids.size());
for (int32_t page_id : block.page_ids) {
page_indices_h.push_back(page_id);
}
last_page_len_h.push_back(
block.seq_length == 0
? 0
: (block.seq_length - block.sink_length + block.sliding_window_offset - 1) %
page_size_ +
1);
sliding_window_offset_h.push_back(block.sliding_window_offset);
sink_size_h.push_back(block.sink_length);
k_rope_pos_offset_h.push_back(block.start_pos);
} else {
// Blocks at maximum depth
const Block& block = global_block_pool_[block_id];
int32_t num_pages = static_cast<int32_t>(block.page_ids.size());
int32_t total_seq_length = static_cast<int32_t>(block.seq_length);
int32_t last_block_id = block_id;
for (int32_t page_id : block.page_ids) {
page_indices_h.push_back(page_id);
}
for (int32_t id : trailing_blocks[i]) {
// Collect trailing blocks if available
const Block& block = global_block_pool_[id];
for (int32_t page_id : block.page_ids) {
page_indices_h.push_back(page_id);
}
num_pages += block.page_ids.size();
total_seq_length += block.seq_length;
last_block_id = id;
}
page_indptr_h.push_back(page_indptr_h.back() + num_pages);
const Block& last_block = global_block_pool_[last_block_id];
last_page_len_h.push_back(total_seq_length == 0
? 0
: (total_seq_length - last_block.sink_length +
last_block.sliding_window_offset - 1) %
page_size_ +
1);
sliding_window_offset_h.push_back(last_block.sliding_window_offset);
sink_size_h.push_back(last_block.sink_length);
k_rope_pos_offset_h.push_back(block.start_pos);
}
last_page_len_h.push_back(block.seq_length == 0 ? 0
: (block.seq_length - block.sink_length +
block.sliding_window_offset - 1) %
page_size_ +
1);
sliding_window_offset_h.push_back(block.sliding_window_offset);
sink_size_h.push_back(block.sink_length);
k_rope_pos_offset_h.push_back(block.start_pos);
}
}
}
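The max-depth branch above accumulates page counts and sequence lengths across the trailing blocks before applying the usual last-page-length formula to the combined chain. A hedged sketch of that formula in isolation, with a worked example assuming page_size = 16:

```cpp
#include <cstdint>

// Sketch of the last-page-length computation used above: after discounting
// sink tokens and adding the sliding-window offset, the final page holds
// ((len - 1) % page_size) + 1 tokens, i.e. a value in 1..page_size.
int32_t LastPageLen(int32_t total_seq_length, int32_t sink_length,
                    int32_t sliding_window_offset, int32_t page_size) {
  if (total_seq_length == 0) return 0;
  return (total_seq_length - sink_length + sliding_window_offset - 1) %
             page_size +
         1;
}
// e.g. LastPageLen(32, 0, 0, 16) == 16 (last page exactly full),
//      LastPageLen(33, 0, 0, 16) == 1  (one token spills onto a new page).
```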
@@ -2041,22 +2076,34 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
/*!
* \brief For the given list of sequences, check the block trace of
each sequence, and return the block ids used by the sequences
* on each depth.
on each depth. If the depth exceeds kPagedKVCacheMaxBlockDepth, the
exceeding blocks are concatenated and returned separately.
* More precisely, the inner returned vector contains the block ids
used by the sequences on a certain depth (or "-1" if a sequence
does not reach that depth). The outer returned vector contains the inner
* vectors from the lowest depth to the highest depth.
*/
std::vector<std::vector<int32_t>> GetBlockIdsOnDepth(
const std::vector<Sequence*>& sequences) const {
std::pair<std::vector<std::vector<int32_t>>, std::vector<std::vector<int32_t>>>
GetBlockIdsOnDepth(const std::vector<Sequence*>& sequences) const {
// - Get the trace of each sequence.
int64_t num_depths = 0;
std::vector<std::vector<int32_t>> seq_block_traces;
std::vector<std::vector<int32_t>> trailing_block_traces;
seq_block_traces.reserve(cur_batch_size_);
trailing_block_traces.reserve(cur_batch_size_);
for (int i = 0; i < cur_batch_size_; ++i) {
std::vector<int32_t> trace = sequences[i]->GetBlockTrace(global_block_pool_);
num_depths = std::max(num_depths, static_cast<int64_t>(trace.size()));
seq_block_traces.push_back(std::move(trace));
if (static_cast<int>(trace.size()) <= kPagedKVCacheMaxBlockDepth) {
seq_block_traces.push_back(std::vector<int32_t>(trace.begin(), trace.end()));
trailing_block_traces.push_back({});
num_depths = std::max(num_depths, static_cast<int64_t>(trace.size()));
} else {
seq_block_traces.push_back(
std::vector<int32_t>(trace.begin(), trace.begin() + kPagedKVCacheMaxBlockDepth));
trailing_block_traces.push_back(
std::vector<int32_t>(trace.begin() + kPagedKVCacheMaxBlockDepth, trace.end()));
num_depths = std::max(num_depths, static_cast<int64_t>(kPagedKVCacheMaxBlockDepth));
}
}

// "Transpose" the traces, yielding the block ids used on each depth.
@@ -2071,7 +2118,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
}
block_ids_on_depths.push_back(std::move(block_ids));
}
return block_ids_on_depths;
return {block_ids_on_depths, trailing_block_traces};
}
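A hedged sketch of the per-sequence split that GetBlockIdsOnDepth performs: each trace is truncated at kPagedKVCacheMaxBlockDepth, and the remainder becomes that sequence's trailing blocks. The standalone function below takes an explicit max_depth parameter purely for illustration:

```cpp
#include <cstdint>
#include <utility>
#include <vector>

// Sketch only: split one block trace into a depth-capped prefix and the
// trailing blocks beyond the cap, as GetBlockIdsOnDepth does per sequence.
std::pair<std::vector<int32_t>, std::vector<int32_t>> SplitTrace(
    const std::vector<int32_t>& trace, int max_depth) {
  if (static_cast<int>(trace.size()) <= max_depth) {
    return {trace, {}};
  }
  std::vector<int32_t> prefix(trace.begin(), trace.begin() + max_depth);
  std::vector<int32_t> trailing(trace.begin() + max_depth, trace.end());
  return {prefix, trailing};
}
// e.g. with max_depth = 4, a 6-block trace {b0, ..., b5} yields prefix
// {b0, b1, b2, b3} and trailing {b4, b5}; the trailing pages are later
// appended to the depth-3 entry of the same sequence.
```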

/*!
@@ -2087,7 +2134,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
* input blocks.
*/
std::pair<std::vector<std::pair<int32_t, int32_t>>, bool> GetChunkedBlockIds(
const std::vector<int32_t>& block_ids) const {
const std::vector<int32_t>& block_ids, bool enable_coalesce = true) const {
std::vector<std::pair<int32_t, int32_t>> uncoalesced_block_ids;
std::vector<std::pair<int32_t, int32_t>> coalesced_block_ids;

@@ -2121,8 +2168,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
double coalesce_ratio = 1.0 * page_counter_uncoalesced / page_counter_coalesced;
// Do not coalesce and use batch decode kernel when coalesce ratio is small.
bool use_decode_kernel = is_decode_request_ && coalesce_ratio < 1.1;

return {use_decode_kernel ? uncoalesced_block_ids : coalesced_block_ids, use_decode_kernel};
return {use_decode_kernel || !enable_coalesce ? uncoalesced_block_ids : coalesced_block_ids,
use_decode_kernel};
}
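The enable_coalesce flag threaded in above only changes which list is returned. A hedged sketch of that selection logic, as a standalone stand-in rather than the TVM API, given the page counts tallied for both layouts:

```cpp
#include <cstdint>
#include <utility>
#include <vector>

// Sketch only: choose between the uncoalesced and coalesced block lists the
// way GetChunkedBlockIds does. Each entry is (block id, chunk append length).
std::pair<std::vector<std::pair<int32_t, int32_t>>, bool> ChooseBlockIds(
    std::vector<std::pair<int32_t, int32_t>> uncoalesced,
    std::vector<std::pair<int32_t, int32_t>> coalesced,
    int64_t pages_uncoalesced, int64_t pages_coalesced,
    bool is_decode_request, bool enable_coalesce) {
  double coalesce_ratio = 1.0 * pages_uncoalesced / pages_coalesced;
  // A small ratio means coalescing saves little; prefer the decode kernel.
  bool use_decode_kernel = is_decode_request && coalesce_ratio < 1.1;
  // Forcing no coalescing (the max-depth case above) keeps one entry per
  // sequence, so trailing blocks can later be appended per sequence.
  bool keep_uncoalesced = use_decode_kernel || !enable_coalesce;
  return {keep_uncoalesced ? std::move(uncoalesced) : std::move(coalesced),
          use_decode_kernel};
}
```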

/*! \brief Invoke the "begin forward" functions of underlying kernels. */
@@ -621,8 +621,7 @@ def test_paged_attention_kv_cache_fork_sequence(kv_cache_and_rope_mode):
apply_attention(kv_cache, rope_mode, [((5, 0, -1), 20)], cached_k, cached_v)
apply_attention(kv_cache, rope_mode, [((6, 5, -1), 102)], cached_k, cached_v)
apply_attention(kv_cache, rope_mode, [((7, 0, -1), 3)], cached_k, cached_v)
apply_attention(kv_cache, rope_mode, [((8, 5, -1), 71)], cached_k, cached_v)
apply_attention(kv_cache, rope_mode, [((9, 5, -1), 20)], cached_k, cached_v)
apply_attention(kv_cache, rope_mode, [((8, 5, -1), 71), ((9, 5, -1), 20)], cached_k, cached_v)
# 0 <- 5 <- 6,8,9
# 0 <- 7
# 3 <- 4
@@ -637,15 +636,16 @@ def test_paged_attention_kv_cache_fork_sequence(kv_cache_and_rope_mode):
apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)

apply_attention(kv_cache, rope_mode, [((10, 1, 33), 11)], cached_k, cached_v)
apply_attention(kv_cache, rope_mode, [((11, 0, 60), 45)], cached_k, cached_v)
apply_attention(kv_cache, rope_mode, [((12, 0, 15), 14)], cached_k, cached_v)
apply_attention(kv_cache, rope_mode, [((13, 0, 16), 19)], cached_k, cached_v)
apply_attention(kv_cache, rope_mode, [((14, 0, 17), 19)], cached_k, cached_v)
apply_attention(kv_cache, rope_mode, [((15, 5, 60), 8)], cached_k, cached_v)
apply_attention(kv_cache, rope_mode, [((16, 5, 80), 10)], cached_k, cached_v)
apply_attention(kv_cache, rope_mode, [((17, 5, 75), 11)], cached_k, cached_v)
apply_attention(kv_cache, rope_mode, [((18, 5, 76), 45)], cached_k, cached_v)
apply_attention(kv_cache, rope_mode, [((19, 5, 77), 14)], cached_k, cached_v)
apply_attention(kv_cache, rope_mode, [((11, 0, 60), 45), ((12, 0, 15), 14)], cached_k, cached_v)
apply_attention(kv_cache, rope_mode, [((13, 0, 16), 19), ((14, 0, 17), 19)], cached_k, cached_v)
apply_attention(kv_cache, rope_mode, [((15, 5, 60), 8), ((16, 5, 80), 10)], cached_k, cached_v)
apply_attention(
kv_cache,
rope_mode,
[((17, 5, 75), 11), ((18, 5, 76), 45), ((19, 5, 77), 14)],
cached_k,
cached_v,
)

operation_seq = [
[(6, 1), (11, 1), (13, 1), (9, 1)],
