From 805bc043e7fe21dbcce2655d5436146b71c3f3bf Mon Sep 17 00:00:00 2001
From: saki-daisuki
Date: Mon, 29 Sep 2025 11:20:57 +0800
Subject: [PATCH 1/4] Update multistep.py

---
 ucm/ucm_sparse/kvstar/multistep.py | 344 +++++++++++++++++------------
 1 file changed, 197 insertions(+), 147 deletions(-)

diff --git a/ucm/ucm_sparse/kvstar/multistep.py b/ucm/ucm_sparse/kvstar/multistep.py
index 8aeeec49..3e9145f1 100644
--- a/ucm/ucm_sparse/kvstar/multistep.py
+++ b/ucm/ucm_sparse/kvstar/multistep.py
@@ -18,7 +18,7 @@
     UcmSparseRole,
 )
 from ucm.store.ucmstore import Task, UcmKVStoreBase
-from ucm.ucm_sparse.kvstar.utils import bind_cpus, get_offset, block_hash_func
+from ucm.ucm_sparse.kvstar.utils import bind_cpus, block_hash_func, get_offset

 """
 --------------------------------------------------------------------------------------
@@ -34,15 +34,18 @@
 ReqType = Union[str, int]  # request_id identifier; may be a str (UUID) or an int (unique), consistent with vLLM
 HashType = Union[str, int]  # HashType improves readability and makes it quick to confirm which management dicts are keyed by hash

+
 class ReqStage(enum.Enum):
     PREFILL = enum.auto()
     DECODE = enum.auto()

+
 # NOTE: reserved enum for retrieval task status. TODO: support asynchronous retrieval logic
 class RetrieveTaskStatus(enum.Enum):
     WAITING = enum.auto()
     FINISHED = enum.auto()

+
 # NOTE: reserved Python-side management structure for asynchronous retrieval tasks. TODO: refine according to actual requirements
 @dataclass
 class RetrieveManager:
@@ -50,6 +53,7 @@
     request_ids: List[ReqType]
     retrieve_tasks: dict  # task_id/request_id, task_status

+
 # request-level sparse meta information
 @dataclass
 class ReqMeta:
@@ -102,8 +106,8 @@ def stage(self) -> ReqStage:
     @property
     def is_last_chunk(self) -> bool:
         return (
-                self.num_computed_tokens + self.num_scheduled_tokens
-                >= self.num_prompt_tokens
+            self.num_computed_tokens + self.num_scheduled_tokens
+            >= self.num_prompt_tokens
         )

@@ -130,23 +134,23 @@ def prefill_fully_blk_num(self) -> int:
     def query_offload_info(self) -> list | None:
         if self.stage == ReqStage.PREFILL:
             cur_step_parse_prompt_len_end_pos = (
-                    self.num_computed_tokens + self.num_scheduled_tokens
+                self.num_computed_tokens + self.num_scheduled_tokens
             )
             if (
-                    cur_step_parse_prompt_len_end_pos
-                    < self.num_prompt_tokens - self.retrieval_stride
+                cur_step_parse_prompt_len_end_pos
+                < self.num_prompt_tokens - self.retrieval_stride
             ):
                 return None
             # compute which positions in the standby group should be offloaded to
             valid_token_end_pos_in_retrieve_group = self.retrieval_stride - (
-                    self.num_prompt_tokens - cur_step_parse_prompt_len_end_pos
+                self.num_prompt_tokens - cur_step_parse_prompt_len_end_pos
             )
             valid_token_num_in_retrieve_group = min(
                 valid_token_end_pos_in_retrieve_group, self.num_scheduled_tokens
             )
             valid_token_start_pos_in_retrieve_group = (
-                    valid_token_end_pos_in_retrieve_group
-                    - valid_token_num_in_retrieve_group
+                valid_token_end_pos_in_retrieve_group
+                - valid_token_num_in_retrieve_group
             )
             return list(
                 range(
@@ -169,20 +173,20 @@ def __init__(self):
         self.finished_req_ids = []

     def add_request(
-            self,
-            request_id: ReqType,
-            index_in_batch: int,
-            num_prompt_tokens: int,
-            num_output_tokens: int,
-            num_scheduled_tokens: int,
-            num_computed_tokens: int,
-            num_sparsed_tokens: int,
-            vllm_block_ids: list[int],
-            token_blk_size,
-            query_start_loc:int,
-            query_len: int,
-            retrieval_stride: int,
-            prompt_token_ids: list[int],
+        self,
+        request_id: ReqType,
+        index_in_batch: int,
+        num_prompt_tokens: int,
+        num_output_tokens: int,
+        num_scheduled_tokens: int,
+        num_computed_tokens: int,
+        num_sparsed_tokens: int,
+        vllm_block_ids: list[int],
+        token_blk_size,
+        query_start_loc: int,
+        query_len: int,
+        retrieval_stride: int,
+        prompt_token_ids: list[int],
     ) -> None:
         meta = ReqMeta(
             request_id=request_id,
@@ -216,23 +220,21 @@ class ReqPerLayerState:  # naming style kept consistent with vLLM
     """

     def __init__(
-            self,
-            req_meta: ReqMeta,
-            layer_name: str,
-            rank: int,
-            tp_size: int,
-            store_instance: UcmKVStoreBase,
-            store_name: str,
-            sparse_cfg
+        self,
+        req_meta: ReqMeta,
+        layer_name: str,
+        rank: int,
+        tp_size: int,
+        store_instance: UcmKVStoreBase,
+        store_name: str,
+        sparse_cfg,
     ):
         # TODO: add req_id as an attribute later if it is needed
         self.sparse_cfg = sparse_cfg
         self.layer_name = layer_name
         self.layer_id = int(layer_name.split(".")[2])
-        self.blk_repre = (
-            torch.Tensor()
-        )
+        self.blk_repre = torch.Tensor()
         self.block_hashes = []
         self.num_tokens = 0  # the number of all_tokens, prompt+output
@@ -294,10 +296,12 @@ def retrieval_async(self, cur_step: int, topk: int, retrieve_device="cpu"):
                 self.step_group_retrieve_result[retrieve_record] = []
             return

-        self.do_retrieve_query_group[retrieve_record] = (torch.stack(self.standby_query_group[retrieve_record])
-                                                         .to(torch.float16)
-                                                         .contiguous()
-                                                         .to("cpu"))
+        self.do_retrieve_query_group[retrieve_record] = (
+            torch.stack(self.standby_query_group[retrieve_record])
+            .to(torch.float16)
+            .contiguous()
+            .to("cpu")
+        )
         task_id = kvstar_retrieve.AsyncRetrieveByCPU(
             self.do_retrieve_query_group[retrieve_record],
             self.blk_repre,
@@ -316,7 +320,9 @@ def get_retrieve_record(self, cur_step):
         if cur_step == 1:
             retrieve_record = "prefill"
         else:
-            retrieve_record = "decode" + str(cur_step - self.sparse_cfg["retrieval_stride"])
+            retrieve_record = "decode" + str(
+                cur_step - self.sparse_cfg["retrieval_stride"]
+            )
         return retrieve_record

     def extract_block_repre(self, vllm_block_ids, prune_dim_enable=False):
@@ -349,7 +355,7 @@ def extract_block_repre(self, vllm_block_ids, prune_dim_enable=False):
                 k_cache_prune[:, :, i_h, :] = k_cache[:, :, i_h, d_pruned_index[i_h]]
             self.d_pruned_index = d_pruned_index.contiguous().to("cpu")
         elif (
-                self.d_pruned_index is not None
+            self.d_pruned_index is not None
         ):  # when dumping a single decode block, refresh the decode block representation using only the channels learned during prefill, without revisiting all previous full blocks. NOTE: decode sparsification is currently disabled, so the outer logic never reaches this branch
             h, d_pruned = self.d_pruned_index.shape
             d_pruned_index = self.d_pruned_index
@@ -364,7 +370,9 @@ def extract_block_repre(self, vllm_block_ids, prune_dim_enable=False):
             c = self.sparse_cfg["blk_repre_inner_token_merge"]
             M = S // c

-            k_cache_new = k_cache_prune.reshape(n, M, c, h, d_pruned).mean(dim=2)  # nMchd -> nMhd
+            k_cache_new = k_cache_prune.reshape(n, M, c, h, d_pruned).mean(
+                dim=2
+            )  # nMchd -> nMhd

             return k_cache_new
@@ -376,8 +384,8 @@ def prepare_init_and_local_window(self):
         if self.local_window is None:
             return

-        self.k_cache[vllm_block_ids[-self.local_window_sz:]] = self.local_window[0]
-        self.v_cache[vllm_block_ids[-self.local_window_sz:]] = self.local_window[1]
+        self.k_cache[vllm_block_ids[-self.local_window_sz :]] = self.local_window[0]
+        self.v_cache[vllm_block_ids[-self.local_window_sz :]] = self.local_window[1]

     def construct_init_and_local_window(self):
         vllm_block_ids = self.req_meta.vllm_block_ids
@@ -387,7 +395,7 @@ def construct_init_and_local_window(self):
             self.v_cache[vllm_block_ids[: self.init_window_sz]].clone(),
         )
         local_window_sz = min(
-            self.local_window_sz, len(vllm_block_ids[self.init_window_sz:])
+            self.local_window_sz, len(vllm_block_ids[self.init_window_sz :])
         )
         if local_window_sz > 0:
             self.local_window = (
@@ -397,11 +405,11 @@

     # NOTE: per-request, layer-wise attention_begin/attention_finished; called conditionally inside the same-named functions at the UCMSparse (batched requests) level
     def attention_begin(
-            self,
-            query: torch.Tensor,
-            key: torch.Tensor,
-            value: torch.Tensor,
-            forward_context: ForwardContext,
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        forward_context: ForwardContext,
     ) -> None:
         # ------------------------- offload query ---------------------------------
         # 1. first get the query length of this request
@@ -416,10 +424,12 @@ def attention_begin(
             if self.blk_repre is None:
                 return
             assert (
-                    query_len == 1
+                query_len == 1
             ), "KVStar series sparse attention doesn't support spec_decode now"
             group_record, step_idx_in_retrieve_group = self.get_decode_step_record()
-            self.save_to_standby(group_record, step_idx_in_retrieve_group, query_start_loc, query)
+            self.save_to_standby(
+                group_record, step_idx_in_retrieve_group, query_start_loc, query
+            )

             if self.req_meta.step % self.sparse_cfg["retrieval_stride"] == 0:
                 candidate_swap_vllm_block_ids = self.get_retrieve_candidate_block_ids()
@@ -427,15 +437,21 @@ def attention_begin(
                 # for step 1, submit and wait for the retrieval over the last 8 prefill tokens
                 # for step 9, submit the retrieval task for steps 1-8 and wait for the prefill last-8-token retrieval
                 # for step 17, submit the retrieval task for steps 9-16 and wait for the steps 1-8 retrieval task
-                self.retrieval_async(self.req_meta.step + 1, len(candidate_swap_vllm_block_ids))  # asynchronous logic
+                self.retrieval_async(
+                    self.req_meta.step + 1, len(candidate_swap_vllm_block_ids)
+                )  # asynchronous logic
                 # self.retrieval_sync(self.req_meta.step, len(candidate_swap_vllm_block_ids))
             if self.req_meta.step == 1:
                 self.prepare_init_and_local_window()
                 # special handling for step 1: wait for the retrieval task to finish, then load serially and wait for the load to complete
                 candidate_swap_vllm_block_ids = self.get_retrieve_candidate_block_ids()
                 self.wait_for_blk_transfer_task_done()
-                self.retrieval_async(self.req_meta.step, len(candidate_swap_vllm_block_ids))  # asynchronous logic
-                self.load_retrieve_result_async(self.req_meta.step, candidate_swap_vllm_block_ids)
+                self.retrieval_async(
+                    self.req_meta.step, len(candidate_swap_vllm_block_ids)
+                )  # asynchronous logic
+                self.load_retrieve_result_async(
+                    self.req_meta.step, candidate_swap_vllm_block_ids
+                )
             if self.req_meta.step % self.sparse_cfg["retrieval_stride"] == 1:
                 # need to wait for the retrieved cache to finish loading
                 self.wait_for_blk_transfer_task_done()
@@ -450,13 +466,22 @@ def offload_prefill_query(self, query, query_len, query_start_loc):
         offload_query_len = len(chunk_prefill_query_offload_info)
         # 3. slice out the query tokens that need to be offloaded
         assert query_len >= offload_query_len
-        tokens_to_offload = query[query_start_loc + query_len - offload_query_len:
-                                  query_start_loc + query_len]
+        tokens_to_offload = query[
+            query_start_loc
+            + query_len
+            - offload_query_len : query_start_loc
+            + query_len
+        ]
         group_record = "prefill"
         for query_relative_idx, in_query_group_idx in enumerate(
-                chunk_prefill_query_offload_info
+            chunk_prefill_query_offload_info
         ):
-            self.save_to_standby(group_record, in_query_group_idx, query_relative_idx, tokens_to_offload)
+            self.save_to_standby(
+                group_record,
+                in_query_group_idx,
+                query_relative_idx,
+                tokens_to_offload,
+            )

     def load_retrieve_result_async(self, load_step, candidate_swap_vllm_block_ids):
         if load_step <= self.sparse_cfg["retrieval_stride"] * 2:
@@ -465,9 +490,9 @@ def load_retrieve_result_async(self, load_step, candidate_swap_vllm_block_ids):
            cur_group_idx = int(
                math.ceil(load_step / self.sparse_cfg["retrieval_stride"])
            )  # e.g. step 17 / 8 = group 3
-            wait_retrieve_step_idx = (
-                cur_group_idx - 3
-            ) * self.sparse_cfg["retrieval_stride"] + 1
+            wait_retrieve_step_idx = (cur_group_idx - 3) * self.sparse_cfg[
+                "retrieval_stride"
+            ] + 1
             need_retrieve_record = "decode" + str(wait_retrieve_step_idx)
         if self.step_group_retrieve_result.get(need_retrieve_record) is None:
             async_retrieve_task_id = self.task_waiter[need_retrieve_record]
@@ -481,7 +506,7 @@ def load_retrieve_result_async(self, load_step, candidate_swap_vllm_block_ids):
             topk_indices = task_result["data"]  # KVSTAR_RETRIEVE
             init_window_sz = self.sparse_cfg["init_window_sz"]
             select_blk_hashes = [
-                self.block_hashes[int(id_) + init_window_sz] for id_ in topk_indices
+                self.block_hashes[int(id_) + init_window_sz] for id_ in topk_indices
             ]
             self.step_group_retrieve_result[need_retrieve_record] = (
                 select_blk_hashes
@@ -497,7 +522,7 @@ def load_retrieve_result_async(self, load_step, candidate_swap_vllm_block_ids):

         # ------------------------- trigger asynchronous block loading ---------------------------------
         # after the first iteration step fetches the prefill retrieval result, it is reused by the first two decode groups; only the third group starts fetching the subsequent blocks
-        if (need_retrieve_record != "prefill" or load_step == 1):
+        if need_retrieve_record != "prefill" or load_step == 1:
             if len(retrieve_result_hash_list) > 0:
                 self.launch_transfer_task(
                     "load", retrieve_result_hash_list, candidate_swap_vllm_block_ids
@@ -506,21 +531,31 @@ def load_retrieve_result_async(self, load_step, candidate_swap_vllm_block_ids):

     def get_retrieve_candidate_block_ids(self):
         candidate_swap_vllm_block_ids = self.req_meta.vllm_block_ids[
-            self.init_window_sz:
-            math.ceil(self.blk_repre.shape[0] * self.sparse_cfg["sparse_ratio"]) + self.init_window_sz
-        ]
+            self.init_window_sz : math.ceil(
+                self.blk_repre.shape[0] * self.sparse_cfg["sparse_ratio"]
+            )
+            + self.init_window_sz
+        ]
         return candidate_swap_vllm_block_ids

     def get_decode_step_record(self):
         cur_decode_step = self.req_meta.step
-        step_idx_in_retrieve_group = (cur_decode_step - 1) % self.sparse_cfg["retrieval_stride"]
-        belong_retrieve_group = ((cur_decode_step - 1) // self.sparse_cfg["retrieval_stride"]) * self.sparse_cfg["retrieval_stride"] + 1
+        step_idx_in_retrieve_group = (cur_decode_step - 1) % self.sparse_cfg[
+            "retrieval_stride"
+        ]
+        belong_retrieve_group = (
+            (cur_decode_step - 1) // self.sparse_cfg["retrieval_stride"]
+        ) * self.sparse_cfg["retrieval_stride"] + 1
         group_record = "decode" + str(belong_retrieve_group)
         return group_record, step_idx_in_retrieve_group

-    def save_to_standby(self, group_record, in_query_group_idx, query_relative_idx, tokens_to_offload):
+    def save_to_standby(
+        self, group_record, in_query_group_idx, query_relative_idx, tokens_to_offload
+    ):
         if group_record not in self.standby_query_group.keys():
-            self.standby_query_group[group_record] = [None] * self.sparse_cfg["retrieval_stride"]
+            self.standby_query_group[group_record] = [None] * self.sparse_cfg[
+                "retrieval_stride"
+            ]
         self.standby_query_group[group_record][in_query_group_idx] = tokens_to_offload[
             query_relative_idx
         ].clone()
@@ -528,7 +563,10 @@ def compute_block_repre(self, num_blocks_need_dump):
         if self.req_meta.stage == ReqStage.PREFILL and self.req_meta.is_last_chunk:
             self.blk_repre = self.extract_block_repre(
-                self.req_meta.vllm_block_ids[:self.num_blocks_dumped + num_blocks_need_dump], prune_dim_enable=True
+                self.req_meta.vllm_block_ids[
+                    : self.num_blocks_dumped + num_blocks_need_dump
+                ],
+                prune_dim_enable=True,
             )  # NOTE: key step: the representation excludes the leading and trailing window blocks
             if self.blk_repre is not None:
@@ -536,7 +574,7 @@ def compute_block_repre(self, num_blocks_need_dump):
                self.blk_repre = None  # NOTE: smaller than the kept window, no need to record block representations
             else:
                 self.blk_repre = (
-                    self.blk_repre[self.init_window_sz: -self.local_window_sz]
+                    self.blk_repre[self.init_window_sz : -self.local_window_sz]
                     .to(torch.float16)
                     .contiguous()
                     .to("cpu")
@@ -544,23 +582,30 @@ def compute_block_repre(self, num_blocks_need_dump):
         self.construct_init_and_local_window()

     def attention_finished(
-            self,
-            query: torch.Tensor,
-            key: torch.Tensor,
-            value: torch.Tensor,
-            attn_output: torch.Tensor,
-            forward_context: ForwardContext,
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attn_output: torch.Tensor,
+        forward_context: ForwardContext,
     ) -> None:
         if self.req_meta.stage != ReqStage.PREFILL:
-            if self.req_meta.step >= self.sparse_cfg["retrieval_stride"] * 2 and self.req_meta.step % self.sparse_cfg["retrieval_stride"] == 0:
+            if (
+                self.req_meta.step >= self.sparse_cfg["retrieval_stride"] * 2
+                and self.req_meta.step % self.sparse_cfg["retrieval_stride"] == 0
+            ):
                 # when the last iteration step of a decode group finishes its attention computation, start the asynchronous load; the old cache is no longer needed and can be swapped for the cache required by the next group
                 # the KV cache of the first two decode groups is loaded in attention_begin; only the KV cache from the third group onward is loaded here
                 candidate_swap_vllm_block_ids = self.get_retrieve_candidate_block_ids()
-                self.load_retrieve_result_async(self.req_meta.step + 1, candidate_swap_vllm_block_ids)
+                self.load_retrieve_result_async(
+                    self.req_meta.step + 1, candidate_swap_vllm_block_ids
+                )
             return
         # dump the cache only once, during the prefill stage
         self.maybe_register_kv_cache(forward_context)
-        num_tokens_updated = self.req_meta.num_computed_tokens + self.req_meta.num_scheduled_tokens
+        num_tokens_updated = (
+            self.req_meta.num_computed_tokens + self.req_meta.num_scheduled_tokens
+        )
         num_blocks_dumped = self.num_blocks_dumped
         num_full_blocks = num_tokens_updated // self.block_size  # truncate to get the number of full blocks
         num_blocks_need_dump = num_full_blocks - num_blocks_dumped
@@ -584,13 +629,13 @@ def maybe_register_kv_cache(self, forward_context: ForwardContext):

     @classmethod
     def blk_trans_task_hash(
-            cls, block_ids, store_type, tensor_type
+        cls, block_ids, store_type, tensor_type
     ):  # generate a hash that uniquely identifies a block transfer task
         return hash((tuple(block_ids), store_type, tensor_type))

     @classmethod
     def req_state_hash(
-            cls, req_id, layer_name
+        cls, req_id, layer_name
     ):  # generate a hash that uniquely identifies the per-layer request state
         return hash((req_id, layer_name))
@@ -609,27 +654,27 @@ def launch_transfer_task(self, transfer_type, block_hashes, vllm_block_ids):

         # get the offset of each key or value within the UCStore block (a UCStore block aggregates the TP domain and all layers)
         offsets_k = [
-                        get_offset(
-                            block_shape,
-                            self.local_tp_rank,
-                            self.total_tp_size,
-                            precision,
-                            self.layer_id,
-                            is_v=False,
-                            is_mla=is_mla,
-                        )
-                    ] * length
+            get_offset(
+                block_shape,
+                self.local_tp_rank,
+                self.total_tp_size,
+                precision,
+                self.layer_id,
+                is_v=False,
+                is_mla=is_mla,
+            )
+        ] * length
         offsets_v = [
-                        get_offset(
-                            block_shape,
-                            self.local_tp_rank,
-                            self.total_tp_size,
-                            precision,
-                            self.layer_id,
-                            is_v=True,
-                            is_mla=is_mla,
-                        )
-                    ] * length
+            get_offset(
+                block_shape,
+                self.local_tp_rank,
+                self.total_tp_size,
+                precision,
+                self.layer_id,
+                is_v=True,
+                is_mla=is_mla,
+            )
+        ] * length

         # vLLM block locations
         key_src_tensors = [self.k_cache[id_] for id_ in vllm_block_ids]
@@ -646,7 +691,7 @@ def launch_transfer_task(self, transfer_type, block_hashes, vllm_block_ids):
             self.blk_trans_tasks[task_v_hash] = task_v

     def wait_for_blk_transfer_task_done(
-            self,
+        self,
     ):  # waiting logic for some asynchronous tasks. NOTE: distinguish retrieval tasks from block transfer tasks
         for task_hash, task in self.blk_trans_tasks.items():
             # TODO: handle exceptions here, refer to UcmKVConnector
@@ -685,20 +730,24 @@ def __init__(self, vllm_config: VllmConfig, role: UcmSparseRole):
                 localRankId=self.local_tp_rank,
             )
             kvstar_retrieve.Setup(param)
-            self.connector_name = self._vllm_config.kv_transfer_config.kv_connector_extra_config[
-                "ucm_connector_name"
-            ]
+            self.connector_name = (
+                self._vllm_config.kv_transfer_config.kv_connector_extra_config[
+                    "ucm_connector_name"
+                ]
+            )
             self.connector = get_kv_transfer_group().connector
         else:
             self.connector = None
-        #Note: shares the connector with the ucm prefix-cache blocks
+        # Note: shares the connector with the ucm prefix-cache blocks
         assert self._vllm_config.kv_transfer_config is not None

         # the scheduler side also records the config; it may be useful
-        self.kvstar_multistep_cfg = vllm_config.kv_transfer_config.kv_connector_extra_config[
-            "ucm_sparse_config"
-        ]["KVStarMultiStep"]
+        self.kvstar_multistep_cfg = (
+            vllm_config.kv_transfer_config.kv_connector_extra_config[
+                "ucm_sparse_config"
+            ]["KVStarMultiStep"]
+        )
         print(f"kvstar_multistep_cfg: {self.kvstar_multistep_cfg}")

         self.token_blk_size = vllm_config.cache_config.block_size
@@ -714,8 +763,8 @@ def create_layerwise_req_state(self, req_meta, layer_name):
         if req_meta.request_id not in self.req_states:
             if self.req_states.get(req_meta.request_id) is None:
                 self.req_states[req_meta.request_id] = [
-                                                           None
-                                                       ] * self.total_num_hidden_layers
+                    None
+                ] * self.total_num_hidden_layers
         if self.req_states[req_meta.request_id][layer_id] is None:
             self.req_states[req_meta.request_id][layer_id] = ReqPerLayerState(
                 req_meta,
@@ -724,7 +773,7 @@ def create_layerwise_req_state(self, req_meta, layer_name):
                 self.total_tp_size,
                 self.connector,
                 self.connector_name,
-                self.kvstar_multistep_cfg
+                self.kvstar_multistep_cfg,
             )
         return self.req_states[req_meta.request_id][layer_id]
@@ -742,12 +791,12 @@ def request_finished_in_worker(self, request_id: ReqType):
             del self.req_states[request_id]

     def attention_begin(
-            self,
-            query: torch.Tensor,
-            key: torch.Tensor,
-            value: torch.Tensor,
-            layer_name: str,
-            forward_context: ForwardContext,
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        layer_name: str,
+        forward_context: ForwardContext,
     ) -> None:
         """
         This is called at the beginning of "unified_attention".
@@ -767,13 +816,13 @@ def attention_begin(
         req_layerwise_state.attention_begin(query, key, value, forward_context)

     def attention_finished(
-            self,
-            query: torch.Tensor,
-            key: torch.Tensor,
-            value: torch.Tensor,
-            attn_output: torch.Tensor,
-            layer_name: str,
-            forward_context: ForwardContext,
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attn_output: torch.Tensor,
+        layer_name: str,
+        forward_context: ForwardContext,
     ) -> None:
         """
         This is called at the end of "unified_attention".
@@ -794,11 +843,7 @@ def attention_finished(
     # ==============================

     def build_sparse_meta(
-            self,
-            scheduler_output,
-            requests,
-            input_batch,
-            attn_metadata
+        self, scheduler_output, requests, input_batch, attn_metadata
     ) -> None:  # bind is done inside the function
         """
         Build the sparse metadata for this step.
@@ -822,10 +867,10 @@ def build_sparse_meta(
         query_lens = attn_metadata.query_lens

         for (
-                req_id,
-                num_scheduled_tokens,
+            req_id,
+            num_scheduled_tokens,
         ) in (
-                scheduler_output.num_scheduled_tokens.items()
+            scheduler_output.num_scheduled_tokens.items()
         ):  # NOTE: num_scheduled_tokens includes speculative tokens
             req_state = requests[req_id]
             if len(req_state.prompt_token_ids) > self.token_blk_size:
@@ -839,12 +884,14 @@ def build_sparse_meta(
                     scheduler_output.req_sparsed_slots[
                         req_id
                     ],  # the slot budget currently granted (num_sparsed_tokens)
-                    req_state.block_ids[0],  # only a single kv-cache group is supported for now; take element [0] of the tuple
+                    req_state.block_ids[
+                        0
+                    ],  # only a single kv-cache group is supported for now; take element [0] of the tuple
                     self.token_blk_size,
                     query_start_locs[input_batch.req_id_to_index[req_id]].item(),
                     query_lens[input_batch.req_id_to_index[req_id]].item(),
                     self.kvstar_multistep_cfg["retrieval_stride"],
-                    req_state.prompt_token_ids
+                    req_state.prompt_token_ids,
                 )
         self._sparse_metadata = sparse_meta
@@ -863,33 +910,36 @@ def estimate_num_slots_sparsed(self, request: Request) -> int:
         num_prefill_fully_block = request.num_prompt_tokens // block_size

         num_prefill_keep_fixed_blk = min(
-            self.kvstar_multistep_cfg["init_window_sz"] + self.kvstar_multistep_cfg["local_window_sz"], num_prefill_fully_block
+            self.kvstar_multistep_cfg["init_window_sz"]
+            + self.kvstar_multistep_cfg["local_window_sz"],
+            num_prefill_fully_block,
         )

         num_sparse_saved_fully_blk = math.ceil(
-            (num_prefill_fully_block - num_prefill_keep_fixed_blk) * self.kvstar_multistep_cfg["sparse_ratio"]
+            (num_prefill_fully_block - num_prefill_keep_fixed_blk)
+            * self.kvstar_multistep_cfg["sparse_ratio"]
         )  # same as blk_repre.shape[0] * SPARSE_RATIO

         num_blocks_dense_total = math.ceil(request.num_tokens / block_size)  # round up

         num_blocks_be_compressed_prefill = (
-                num_prefill_fully_block
-                - num_sparse_saved_fully_blk
-                - num_prefill_keep_fixed_blk
+            num_prefill_fully_block
+            - num_sparse_saved_fully_blk
+            - num_prefill_keep_fixed_blk
         )
         num_blocks_this_step_budget = (
-                num_blocks_dense_total - num_blocks_be_compressed_prefill
+            num_blocks_dense_total - num_blocks_be_compressed_prefill
         )

         tail_blk_valid_token_num = request.num_tokens % block_size
         if tail_blk_valid_token_num:
             estimate_num_slots_budget = (
-                    num_blocks_this_step_budget - 1
-            ) * block_size + tail_blk_valid_token_num
+                num_blocks_this_step_budget - 1
+            ) * block_size + tail_blk_valid_token_num
         else:
             estimate_num_slots_budget = (
-                    num_blocks_this_step_budget * block_size
+                num_blocks_this_step_budget * block_size
             )  # the next step will fill a full block and trigger a block dump

         return estimate_num_slots_budget
@@ -924,4 +974,4 @@ def allocate_slots(
         if num_blocks_to_allocate > block_pool.get_num_free_blocks():
             return None
         coordinator.allocate_new_blocks(request.request_id, num_slots_sparsed)
-        return KVCacheBlocks(tuple([kept_blocks]))
\ No newline at end of file
+        return KVCacheBlocks(tuple([kept_blocks]))

From 4ff65bf36fb90d9da077061764caa63172f0706f Mon Sep 17 00:00:00 2001
From: saki-daisuki
Date: Mon, 29 Sep 2025 11:21:15 +0800
Subject: [PATCH 2/4] Update utils.py

---
 ucm/ucm_sparse/kvstar/utils.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/ucm/ucm_sparse/kvstar/utils.py b/ucm/ucm_sparse/kvstar/utils.py
index e23286cd..92f82b21 100644
--- a/ucm/ucm_sparse/kvstar/utils.py
+++ b/ucm/ucm_sparse/kvstar/utils.py
@@ -1,7 +1,8 @@
+import hashlib
+import pickle
 import subprocess
 from functools import cache
-import pickle
-import hashlib
+

 @cache
 def get_offset(block_shape, rank, tp_size, precision, layer_id, is_v, is_mla) -> int:
@@ -16,6 +17,7 @@ def get_offset(block_shape, rank, tp_size, precision, layer_id, is_v, is_mla) ->
     v_offset = k_offset + k_min_data_block_size
     return v_offset if is_v else k_offset

+
 @cache
 def md5(input) -> int:
     input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
@@ -30,6 +32,7 @@ def block_hash_func(parent_block_hash, curr_block_token_ids):
     curr_block_token_ids_tuple = tuple(curr_block_token_ids)
     return md5((parent_block_hash, curr_block_token_ids_tuple))

+
 def execute_command(cmd_list):
     with subprocess.Popen(
         cmd_list, shell=False, stdout=subprocess.PIPE, stderr=subprocess.PIPE

From e1fb9cf8fe2a03f20f343fb4996b2cc26b7f4d96 Mon Sep 17 00:00:00 2001
From: saki-daisuki
Date: Mon, 29 Sep 2025 11:21:34 +0800
Subject: [PATCH 3/4] Update offline_inference_kvstar.py

---
 examples/offline_inference_kvstar.py | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/examples/offline_inference_kvstar.py b/examples/offline_inference_kvstar.py
index d05e07ac..69ff487a 100644
--- a/examples/offline_inference_kvstar.py
+++ b/examples/offline_inference_kvstar.py
@@ -23,6 +23,7 @@ def setup_environment_variables():
     os.environ["PYTHONHASHSEED"] = "123456"
     os.environ["VLLM_TORCH_PROFILER_DIR"] = "./vllm_profile"

+
 @contextlib.contextmanager
 def build_llm_with_uc(module_path: str, name: str, model: str):
     ktc = KVTransferConfig(
@@ -41,8 +42,8 @@ def build_llm_with_uc(module_path: str, name: str, model: str):
                     "local_window_sz": 2,
                     "sparse_ratio": 0.25,
                     "retrieval_stride": 8,
-                    "blk_repre_dim_prune_ratio": 0.25, # prune ratio for the block-representation dimensions
-                    "blk_repre_inner_token_merge": 2 # how many tokens within a block are merged into one representation
+                    "blk_repre_dim_prune_ratio": 0.25,  # prune ratio for the block-representation dimensions
+                    "blk_repre_inner_token_merge": 2,  # how many tokens within a block are merged into one representation
                 }
             },
         },
@@ -162,8 +163,13 @@ def main():

     sampling_params = SamplingParams(temperature=0, top_k=1, max_tokens=300)

-    print_output(llm, prompts_prefill_more_than_2_full_blk, sampling_params, "first")
-    print_output(llm, prompts_prefill_more_than_2_full_blk, sampling_params, "second")
+    print_output(
+        llm, prompts_prefill_more_than_2_full_blk, sampling_params, "first"
+    )
+    print_output(
+        llm, prompts_prefill_more_than_2_full_blk, sampling_params, "second"
+    )
+

 if __name__ == "__main__":
     main()

From 0f0576914d3cad2545eb528c79f473a8e8de2c23 Mon Sep 17 00:00:00 2001
From: saki-daisuki
Date: Mon, 29 Sep 2025 11:24:50 +0800
Subject: [PATCH 4/4] Update factory.py

---
 ucm/integration/vllm/ucm_sparse/factory.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/ucm/integration/vllm/ucm_sparse/factory.py b/ucm/integration/vllm/ucm_sparse/factory.py
index eab9cf23..d6cbbcdd 100644
--- a/ucm/integration/vllm/ucm_sparse/factory.py
+++ b/ucm/integration/vllm/ucm_sparse/factory.py
@@ -49,4 +49,6 @@ def create_sparse_method(
     "KvComp", "ucm.sandbox.sparse.kvcomp.kvcomp", "KvComp"
 )
 UcmSparseFactory.register_sparse_method("GSA", "ucm.ucm_sparse.gsa", "GSA")
-UcmSparseFactory.register_sparse_method("KVStarMultiStep", "ucm.ucm_sparse.kvstar.multistep", "KVStarMultiStep")
+UcmSparseFactory.register_sparse_method(
+    "KVStarMultiStep", "ucm.ucm_sparse.kvstar.multistep", "KVStarMultiStep"
+)