Skip to content

Commit

Permalink
[KVCache] Support KVCache decode from forked sequence and pop more to…
Browse files Browse the repository at this point in the history
…kens (#16995)
  • Loading branch information
cyx-6 committed May 20, 2024
1 parent 3cd6673 commit 18a2a25
Showing 1 changed file with 53 additions and 12 deletions.
65 changes: 53 additions & 12 deletions src/runtime/relax_vm/paged_kv_cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -925,10 +925,21 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
if (fork_pos == -1 || fork_pos == parent_it->second.seq_length) {
// Fork at last by appending a new block directly
int32_t parent_block_idx = parent_it->second.last_block_idx;
if (!global_block_pool_[parent_block_idx].seq_length) {
// If parent ends with empty block, fork from parent's parent block
parent_block_idx = global_block_pool_[parent_block_idx].parent_idx;
}
++global_block_pool_[parent_block_idx].external_ref_cnt;
// Update child block start position and parent index
global_block_pool_[child_block_idx].start_pos = parent_it->second.seq_length;
global_block_pool_[child_block_idx].parent_idx = parent_block_idx;
if (global_block_pool_[parent_block_idx].seq_length) {
// If parent is not empty, append a new block
int32_t new_parent_block_idx = GetFreeBlock();
global_block_pool_[new_parent_block_idx].start_pos = parent_it->second.seq_length;
global_block_pool_[new_parent_block_idx].parent_idx = parent_block_idx;
parent_it->second.last_block_idx = new_parent_block_idx;
}
} else {
// Locate the block to fork from and calculate in-block offset
std::vector<int32_t> trace = parent_it->second.GetBlockTrace(global_block_pool_);
Expand Down Expand Up @@ -1038,21 +1049,51 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
auto it = seq_map_.find(seq_id);
CHECK(it != seq_map_.end()) << "The sequence \"" << seq_id << "\" cannot be found in KV cache.";

Block& block = global_block_pool_[it->second.last_block_idx];
CHECK_GE(n, 0) << "The length of popping " << n << " cannot be negative.";
CHECK_LE(n, block.seq_length) << "The sequence only has length " << block.seq_length
<< " in the last block, while the length of pop is " << n
<< " which exceeds the last-block sequence length.";
CHECK_LE(n, it->second.seq_length)
<< "The sequence only has length " << it->second.seq_length
<< ", while the length of pop is " << n << " which exceeds the whole sequence length.";
int32_t block_idx = it->second.last_block_idx;
while (block_idx != -1 && global_block_pool_[block_idx].external_ref_cnt == 0) {
if (n > global_block_pool_[block_idx].seq_length) {
n -= global_block_pool_[block_idx].seq_length;
it->second.seq_length -= global_block_pool_[block_idx].seq_length;
for (int32_t page_id : global_block_pool_[block_idx].page_ids) {
free_page_ids_.push_back(page_id);
}
free_block_idx_.push_back(block_idx);
block_idx = global_block_pool_[block_idx].parent_idx;
it->second.last_block_idx = block_idx;
continue;
}
if (n <= global_block_pool_[block_idx].seq_length) {
int64_t cur_npage = global_block_pool_[block_idx].page_ids.size();
int64_t tgt_npage =
(global_block_pool_[block_idx].seq_length - n + page_size_ - 1) / page_size_;
while (cur_npage > tgt_npage) {
free_page_ids_.push_back(global_block_pool_[block_idx].page_ids.back());
global_block_pool_[block_idx].page_ids.pop_back();
--cur_npage;
}
it->second.seq_length -= n;
global_block_pool_[block_idx].seq_length -= n;
n = 0;
break;
}
}

int64_t cur_npage = block.page_ids.size();
int64_t tgt_npage = (block.seq_length - n + page_size_ - 1) / page_size_;
while (cur_npage > tgt_npage) {
free_page_ids_.push_back(block.page_ids.back());
block.page_ids.pop_back();
--cur_npage;
if (n) {
int32_t temp_seq_id = -1 - seq_id;
CHECK(seq_map_.find(temp_seq_id) == seq_map_.end());
ForkSequence(seq_id, temp_seq_id, it->second.seq_length - n);
CHECK(seq_map_.find(temp_seq_id) != seq_map_.end());
RemoveSequence(seq_id);
CHECK(seq_map_.find(seq_id) == seq_map_.end());
auto it = seq_map_.find(temp_seq_id);
seq_map_.insert({seq_id, Sequence(global_block_pool_, it->second.last_block_idx)});
seq_map_.erase(temp_seq_id);
}
it->second.seq_length -= n;
block.seq_length -= n;

dirty_aux_data_device_ = true;
}

Expand Down

0 comments on commit 18a2a25

Please sign in to comment.