From 6755e314d912e256ff2e14fdf44a643c8a8eee24 Mon Sep 17 00:00:00 2001
From: zbb200819 <1130072360@qq.com>
Date: Mon, 1 Dec 2025 16:19:38 +0800
Subject: [PATCH 1/8] adapt nfsstore

---
 ucm/sparse/gsa/gsa.py                         | 307 +++++++++++++-----
 .../gsa/offload_ops/src/select_topk_block.cpp |   7 +-
 ucm/sparse/gsa/prefetch/include/kvcache_pre.h |   4 +-
 ucm/sparse/gsa/prefetch/prefetch_engine.py    |  22 +-
 ucm/sparse/gsa/prefetch/src/kvcache_pre.cpp   |  95 +++---
 ucm/sparse/gsa/prefetch/src/pybinds.cpp       |   2 +-
 6 files changed, 301 insertions(+), 136 deletions(-)

diff --git a/ucm/sparse/gsa/gsa.py b/ucm/sparse/gsa/gsa.py
index a0b430f0..29e5c0b1 100644
--- a/ucm/sparse/gsa/gsa.py
+++ b/ucm/sparse/gsa/gsa.py
@@ -22,6 +22,7 @@
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.request import Request
 
+from ucm.integration.vllm.ucm_connector import RequestHasher
 from ucm.sparse.base import (
     INVALID_SLOT,
     UcmSparseBase,
@@ -42,7 +43,7 @@
 
 
 class GSAReqStat:
-    def __init__(self, req_id, block_size) -> None:
+    def __init__(self, req_id, vllm_config: VllmConfig) -> None:
         self.req_id = req_id
         self.repre_slot_mapping = []
         self.calc_block_table = []
@@ -62,11 +63,15 @@ def __init__(self, req_id, block_size) -> None:
         self.init_window_kv = None
         self.local_window_kv = []
         self.sparse_len = 0
-        self.block_size = block_size
+        self.block_size = vllm_config.cache_config.block_size
         self.block_hashes = None
         self.num_prompt_blocks = 0
         self.reamin_map = None
         self.prefetch_map = None
+        self._vllm_config = vllm_config
+        self.rank = vllm_config.parallel_config.rank
+        self.use_mla = vllm_config.model_config.use_mla
+        self.request_hasher = RequestHasher(vllm_config, 0)
 
     def step(self) -> int:
         return self.num_output_tokens
@@ -92,23 +97,36 @@ def is_last_chunk(self) -> bool:
 
     def get_seq_len(self) -> int:
         return self.num_computed_tokens + self.num_scheduled_tokens
-
+
     def set_block_hashes(self, token_ids):
         if self.block_hashes is not None:
             return
         self.block_hashes = []
-        parent_block_hash_value = None
+
+        parent_block_hash_value = compute_parent_block_hash(
+            self._vllm_config.model_config.model,
+            self._vllm_config.parallel_config.world_size,
+            self._vllm_config.model_config.dtype,
+            seed_rank=0,
+        )
+
         for start in range(0, len(token_ids), self.block_size):
             end = start + self.block_size
             block_token_ids = token_ids[start:end]
             if len(block_token_ids) < self.block_size:
                 break
            curr_block_token_ids_tuple = tuple(block_token_ids)
-            block_hash = block_hash_func(
-                parent_block_hash_value, curr_block_token_ids_tuple
+            hash_value = self.request_hasher(
+                (parent_block_hash_value, curr_block_token_ids_tuple)
             )
-            self.block_hashes.append(str(block_hash))
-            parent_block_hash_value = block_hash
+            parent_block_hash_value = hash_value
+
+        if self.rank != 0 and not self.is_mla:
+            self.newqrequest_hasher = RequestHasher(self._vllm_config, self.rank)
+            for i, ucm_block_id in enumerate(self.block_hashes):
+                self.block_hashes[i] = str(
+                    self.newqrequest_hasher(ucm_block_id)
+                )
 
     def add_req_new(
         self, num_scheduled_tokens, add_req_state, index_in_batch, offset
@@ -240,11 +258,12 @@ def _update_slot(
 
 
 class GSAMetaData(UcmSparseMetadata):
-    def __init__(self, block_size, device, use_mla):
+    def __init__(self, vllm_config: VllmConfig):
         self.gsa_stats = {}
-        self.block_size = block_size
-        self.device = device
-        self.use_mla = use_mla
+        self.block_size = vllm_config.cache_config.block_size
+        self.device = vllm_config.device_config.device_type
+        self.use_mla = vllm_config.model_config.use_mla
+        self._vllm_config = vllm_config
 
     def get_model_input(
         self,
@@ -260,7 +279,7 @@ def get_model_input(
             if scheduler_output.scheduled_cached_reqs.resumed_from_preemption[index]:
                 del self.gsa_stats[req_id]
                 prefetch_engine.del_finish_meta(req_id, False)
-            self.gsa_stats[req_id] = GSAReqStat(req_id, self.block_size)
+            self.gsa_stats[req_id] = GSAReqStat(req_id, self.block_size, self._vllm_config)
             self.gsa_stats[req_id].add_req_new(
                 scheduler_output.num_scheduled_tokens[req_id],
                 requests[req_id],
@@ -276,7 +295,7 @@ def get_model_input(
         for new_req in scheduler_output.scheduled_new_reqs:
             if new_req.req_id in self.gsa_stats:
                 del self.gsa_stats[new_req.req_id]
-            self.gsa_stats[new_req.req_id] = GSAReqStat(new_req.req_id, self.block_size)
+            self.gsa_stats[new_req.req_id] = GSAReqStat(new_req.req_id, self.block_size, self._vllm_config)
             self.gsa_stats[new_req.req_id].add_req_new(
                 scheduler_output.num_scheduled_tokens[new_req.req_id],
                 requests[new_req.req_id],
@@ -293,20 +312,10 @@ def trans_input_tensor(self, scheduler_output: SchedulerOutput):
         query_locals = [0] * (batch_size + 1)
         if self.use_mla:
             query_locals_prefill = [0] * (batch_size + 1)
-        if CUDA_TOPK:
-            repre_slot_mapping = [0] * batch_size
-            include_mask = [0] * batch_size
-            exclude_mask = [0] * batch_size
         for req_id, num_tokens in scheduler_output.num_scheduled_tokens.items():
             req_in_batch = self.gsa_stats[req_id].index_in_batch
             calc_block_table += self.gsa_stats[req_id].calc_block_table
             calc_repre_slot_mappings += self.gsa_stats[req_id].calc_repre_slot_mapping
-            if CUDA_TOPK:
-                repre_slot_mapping[req_in_batch] = self.gsa_stats[
-                    req_id
-                ].repre_slot_mapping
-                include_mask[req_in_batch] = self.gsa_stats[req_id].include_mask
-                exclude_mask[req_in_batch] = self.gsa_stats[req_id].exclude_mask
             query_locals[req_in_batch + 1] = scheduler_output.num_scheduled_tokens[
                 req_id
             ]
@@ -321,16 +330,6 @@ def trans_input_tensor(self, scheduler_output: SchedulerOutput):
         model_input["calc_repre_slot_mapping"] = torch.tensor(
             calc_repre_slot_mappings, dtype=torch.int32, device="cpu"
         )
-        if CUDA_TOPK:
-            model_input["repre_slot_mapping"] = make_tensor_with_pad(
-                repre_slot_mapping, pad=0, dtype=torch.int32, device=self.device
-            )
-            model_input["include_mask"] = make_tensor_with_pad(
-                include_mask, pad=False, dtype=torch.uint8, device=self.device
-            )
-            model_input["exclude_mask"] = make_tensor_with_pad(
-                exclude_mask, pad=False, dtype=torch.uint8, device=self.device
-            )
         model_input["query_locals"] = query_locals
         if self.use_mla:
             model_input["query_locals_prefill"] = query_locals_prefill
@@ -406,11 +405,8 @@ def set_topk_caches(self, cal_topk_id, topk_caches, topk_len_list):
     def cal_topk(self, intermediate_q, current_layer_id):
         bs = len(self.cal_topk_id)
         head_group_num = self.att_num_heads // self.kv_num_heads
-        if self.use_mla:
-            q_decode = intermediate_q
-        else:
-            q_decode = intermediate_q[self.cal_topk_id]
-        kpre_index = self.repre_slot_mapping[self.cal_topk_id].flatten()
+        q_decode = intermediate_q[self.cal_topk_id]
+        kpre_index = self.repre_slot_mapping.flatten()
         kpre_need = self.kpre_caches[current_layer_id][kpre_index]
         max_norm_num = kpre_need.shape[1]
         kpre_out = kpre_need.unsqueeze(2).expand(-1, -1, head_group_num, -1, -1)
@@ -422,10 +418,10 @@ def cal_topk(self, intermediate_q, current_layer_id):
         )
         dot_product_weights = attention_weights_without_norm.mean(1)
         dot_product_weights.masked_fill_(
-            self.include_mask[self.cal_topk_id] == 1, float("inf")
+            self.include_mask == 1, float("inf")
         )
         dot_product_weights.masked_fill_(
-            self.exclude_mask[self.cal_topk_id] == 1, float("-inf")
+            self.exclude_mask == 1, float("-inf")
float("-inf") ) selected_block_nums = self.topk_len_list[0] _, top_indices = torch.topk( @@ -433,6 +429,47 @@ def cal_topk(self, intermediate_q, current_layer_id): ) self.topk_caches[current_layer_id][self.cal_topk_id] = top_indices +@cache +def get_offset(block_shape, rank, tp_size, precision, layer_id, is_v, is_mla) -> int: + block_size, num_key_heads_per_tp, head_size = block_shape + k_min_data_block_size = block_size * num_key_heads_per_tp * head_size * precision + v_min_data_block_size = k_min_data_block_size if not is_mla else 0 + layer_size = (k_min_data_block_size + v_min_data_block_size) * ( + tp_size if not is_mla else 1 + ) + if is_mla: + k_offset = layer_size * layer_id + else: + k_offset = layer_size * layer_id + layer_size // tp_size * rank + v_offset = k_offset + k_min_data_block_size + return v_offset if is_v else k_offset + +@cache +def compute_parent_block_hash(model_name, world_size, dtype, seed_rank=0) -> int: + meta = f"{model_name}:{world_size}:{dtype}:{seed_rank}" + meta_bytes = meta.encode("utf-8") + h_seed = hashlib.md5(meta_bytes + b"UCM_HASH_SEED").digest() + return int.from_bytes(h_seed, byteorder="big") + +@cache +def compute_layer_offset( + block_data_size: int, + layer_id: int, + is_v: bool, + is_mla: bool, +) -> int: + layer_data_size = block_data_size if is_mla else block_data_size * 2 + + k_offset = layer_data_size * layer_id + + if is_mla: + return k_offset + + v_offset = k_offset + block_data_size + return v_offset if is_v else k_offset + +def task_hash_func(block_ids, store_type, tensor_type): + return hash((tuple(block_ids), store_type, tensor_type)) class GSA(UcmSparseBase): def __init__(self, vllm_config: VllmConfig, role: UcmSparseRole): @@ -463,13 +500,14 @@ def __init__(self, vllm_config: VllmConfig, role: UcmSparseRole): self.connector = get_kv_transfer_group().connector else: self.connector = None + self.is_python_load = torch.npu.is_available() if CUDA_TOPK: self.prefetch_engine = GSAPrefetchBase( - vllm_config, 16, True, False, False, 1 + vllm_config, 16, True, False, False, 1, self.is_python_load ) else: self.prefetch_engine = GSAPrefetchBase( - vllm_config, 16, True, True, False, 1 + vllm_config, 16, True, True, False, 1, self.is_python_load ) self.topk_kpre_manger = TopKAndKpreManger(MAX_BS) self.gsa_metadata = None @@ -479,6 +517,7 @@ def __init__(self, vllm_config: VllmConfig, role: UcmSparseRole): self.decode_index = [] self.copy_k_flag = [False] * self.layer_num gsa_config.set_config(self.block_size) + self.task_load = {} def init_topk_cal( self, @@ -510,7 +549,7 @@ def init_topk_cal( ) if CUDA_TOPK: self.gsa_cuda_topk = TopkCal( - att_num_heads, kv_num_heads, head_size, prefetch_engine.kpre_caches + att_num_heads, kv_num_heads, head_size, prefetch_engine.kpre_caches, self.use_mla ) def copy_q(self, query: torch.Tensor, current_layer_id: int) -> None: @@ -562,7 +601,8 @@ def copy_k(self, layer_name: str, forward_context: ForwardContext) -> None: .kv_cache[forward_context.virtual_engine][block_ids] .mean(dim=1, keepdim=True) ) - key_cache_mean_out = torch.unsqueeze(key_cache_mean_out, 1) + if torch.cuda.is_available(): + key_cache_mean_out = torch.unsqueeze(key_cache_mean_out, 1) if CUDA_TOPK: self.prefetch_engine.kpre_caches[current_layer_id][ calc_repre_slot_mappings @@ -625,7 +665,7 @@ def attention_begin( current_layer_id ][self.decode_index] else: - attn_metadata.decode.block_tables[ + attn_metadata.decode.block_table[ : len(self.prefetch_engine.req_ids_bs) ].copy_( self.model_input["block_tables_mp"][current_layer_id][ @@ -652,7 
+692,7 @@ def attention_finished( if not self.copy_k_flag[current_layer_id]: self.copy_k(layer_name, forward_context) self.copy_k_flag[current_layer_id] = True - if self.use_mla: + if self.use_mla and torch.cuda.is_available(): return for req_id in self.prefetch_engine.req_ids_bs: assert req_id in self.gsa_metadata.gsa_stats @@ -819,7 +859,7 @@ def build_gsa_metadata( if not self.topk_kpre_manger.is_exist(req_id): index = self.topk_kpre_manger.alloc(req_id) assert index != None - gsa_meta = GSAMetaData(self.block_size, self.device, self.use_mla) + gsa_meta = GSAMetaData(self._vllm_config) gsa_meta.gsa_stats = self.gsa_stats self.model_input = gsa_meta.get_model_input( scheduler_output, @@ -867,11 +907,91 @@ def execute_finished(self): layer_id = int(layer_name.split(".")[2]) kv_caches[layer_id] = kv_cache if PTOPK_PREFETCH_ENABLE: - self.prefetch_engine.deal_async_prefetch( - self.gsa_metadata, kv_caches, self.connector.cc_store() + if self.is_python_load: + is_prefetch_done = self.check_transfer_task_done() + else: + is_prefetch_done = self.prefetch_engine.prefetch_engine_c.get_prefetch_status() + all_free_block_ids, all_miss_ids = self.prefetch_engine.deal_async_prefetch( + is_prefetch_done, self.gsa_metadata, kv_caches, self.connector.cc_store() ) + if is_prefetch_done: + self.load_transfer_task(all_free_block_ids, all_miss_ids, kv_caches) else: - self.prefetch_engine.deal_async_prefetch(self.gsa_metadata, kv_caches, None) + self.prefetch_engine.deal_async_prefetch(False, self.gsa_metadata, kv_caches, None) + + def launch_transfer_task(self, all_free_block_ids, all_miss_ids, kv_caches): + if all_free_block_ids == None: + return + fn = getattr(self.connector, 'load') + precision = self.element_size + if self.use_mla: + block_data_size = kv_caches[0].numel() * precision + else: + block_data_size = kv_caches[0][0].numel() * precision + + offsets_k = [] + key_src_tensors = [] + block_hashes = [] + + for req_id in all_free_block_ids.keys(): + req_block_hash = self.gsa_metadata.gsa_stats[req_id].block_hashes + for layer_id in range(self.layer_num): + length = len(all_free_block_ids[req_id][layer_id]) + if length == 0: + continue + + offset_k = compute_layer_offset( + block_data_size, + layer_id, + is_v=False, + is_mla=self.use_mla, + ) + offsets_k += [offset_k] * length + block_hashes += [ + req_block_hash[i] + for i in all_miss_ids[req_id][layer_id] + ] + + if not self.use_mla: + key_src_tensors += [ + kv_caches[layer_id][0][_id] + for _id in all_free_block_ids[req_id][layer_id] + ] + offset_v = compute_layer_offset( + block_data_size, + layer_id, + is_v=True, + is_mla=self.use_mla, + ) + offsets_k += [offset_v] * length + block_hashes += [ + req_block_hash[i] + for i in all_miss_ids[req_id][layer_id] + ] + key_src_tensors += [ + kv_caches[layer_id][1][_id] + for _id in all_free_block_ids[req_id][layer_id] + ] + else: + key_src_tensors += [ + kv_caches[layer_id][_id] + for _id in all_free_block_ids[req_id][layer_id] + ] + + task_all = fn(block_hashes, offsets_k, key_src_tensors) + task_all_hash = task_hash_func(block_hashes, "load", "value") + self.task_load[task_all_hash] = task_all + + def check_transfer_task_done(self) -> bool: + if len(self.task_load) == 0: + return True + + for task_hash, task in self.task_load.items(): + ret = self.connector.check(task) + if not ret: + return False + self.task_load.clear() + return True def build_sparse_meta( self, scheduler_output: SchedulerOutput, requests, input_batch, attn_metadata @@ -926,46 +1046,67 @@ def estimate_num_slots_sparsed(self, 
         return num_tokens_sparsed
 
     def _start_topk_cal(self) -> None:
-        cal_topk_id = []
-        is_decode = []
-        topk_len_list = []
-        repre_slot_mappings = []
-        calc_block_tables = []
-        calc_repre_slot_mappings = []
-        for req_id in self.prefetch_engine.req_ids_bs:
-            req_meta = self.gsa_metadata.gsa_stats[req_id]
-            if req_meta.is_gsa():
-                cal_topk_id.append(req_meta.index_in_batch)
-                is_decode.append(True)
-                one_topk_len = (
-                    gsa_config.compute_topk_len(len(req_meta.blocks))
-                    + gsa_config.num_prefetch_blocks
-                )
-                topk_len_list.append(one_topk_len)
-            else:
-                is_decode.append(False)
-            repre_slot_mappings.append(req_meta.repre_slot_mapping)
-            calc_block_tables = self.model_input["calc_block_table"]
-            calc_repre_slot_mappings += req_meta.calc_repre_slot_mapping
-        if CUDA_TOPK and len(topk_len_list) != 0:
-            topk_len_list = [max(topk_len_list)] * len(topk_len_list)
-        self.gsa_offload_ops.set_common_param(cal_topk_id, is_decode)
-        if len(calc_block_tables) != 0:
-            self.gsa_offload_ops.set_kpre_param(
-                calc_block_tables, calc_repre_slot_mappings
-            )
         if self.prefetch_engine.atb_gsa_enable and self.prefetch_engine.is_topk_cal:
+            cal_topk_id = []
+            is_decode = []
+            topk_len_list = []
+            repre_slot_mappings = []
+            repre_slot_mappings_all = []
+            include_masks = []
+            exclude_masks = []
+            for req_id in self.prefetch_engine.req_ids_bs:
+                req_meta = self.gsa_metadata.gsa_stats[req_id]
+                if req_meta.is_gsa():
+                    cal_topk_id.append(req_meta.index_in_batch)
+                    is_decode.append(True)
+                    one_topk_len = (
+                        gsa_config.compute_topk_len(len(req_meta.blocks))
+                        + gsa_config.num_prefetch_blocks
+                    )
+                    topk_len_list.append(one_topk_len)
+                    if CUDA_TOPK:
+                        include_masks.append(
+                            req_meta.include_mask
+                        )
+                        exclude_masks.append(
+                            req_meta.exclude_mask
+                        )
+                        repre_slot_mappings.append(
+                            req_meta.repre_slot_mapping
+                        )
+
+                else:
+                    is_decode.append(False)
+                repre_slot_mappings_all.append(req_meta.repre_slot_mapping)
+
+            if CUDA_TOPK and len(topk_len_list) != 0:
+                topk_len_list = [max(topk_len_list)] * len(topk_len_list)
+                repre_slot_mappings = make_tensor_with_pad(
+                    repre_slot_mappings, pad=0, dtype=torch.int32, device=self.device
+                )
+                include_masks = make_tensor_with_pad(
+                    include_masks, pad=False, dtype=torch.uint8, device=self.device
+                )
+                exclude_masks = make_tensor_with_pad(
+                    exclude_masks, pad=True, dtype=torch.uint8, device=self.device
+                )
+            self.gsa_offload_ops.set_common_param(cal_topk_id, is_decode)
+            if len(self.model_input["calc_block_table"]) != 0:
+                self.gsa_offload_ops.set_kpre_param(
+                    self.model_input["calc_block_table"], []
+                )
+
             if CUDA_TOPK:
                 self.gsa_cuda_topk.set_topk_param(
-                    self.model_input["repre_slot_mapping"],
-                    self.model_input["include_mask"],
-                    self.model_input["exclude_mask"],
+                    repre_slot_mappings,
+                    include_masks,
+                    exclude_masks,
                 )
                 self.gsa_cuda_topk.set_topk_caches(
                     cal_topk_id, self.model_input["topk_caches"], topk_len_list
                 )
             else:
-                self.gsa_offload_ops.set_topk_param(repre_slot_mappings)
+                self.gsa_offload_ops.set_topk_param(repre_slot_mappings_all)
                 self.gsa_offload_ops.set_topk_cache(
                     self.model_input["topk_caches"], topk_len_list
                 )
diff --git a/ucm/sparse/gsa/offload_ops/src/select_topk_block.cpp b/ucm/sparse/gsa/offload_ops/src/select_topk_block.cpp
index e5e87ae2..fe7afa5a 100644
--- a/ucm/sparse/gsa/offload_ops/src/select_topk_block.cpp
+++ b/ucm/sparse/gsa/offload_ops/src/select_topk_block.cpp
@@ -29,6 +29,9 @@ void TopkBlockSelector::TopKImpl(const float* scores, uint32_t numScores, uint32
     for (uint32_t i = 0; i < startWindow_; ++i) {
         topkIndices[idx++] = i;
     }
+    for (uint32_t i = 0; i < endWindow_; ++i) {
+        topkIndices[idx++] = numScores - endWindow_ + i;
+    }
     int32_t midCount = k - startWindow_ - endWindow_;
     if (midCount > 0) {
         std::vector middleIndices;
@@ -44,10 +47,6 @@ void TopkBlockSelector::TopKImpl(const float* scores, uint32_t numScores, uint32
             topkIndices[idx++] = middleIndices[i];
         }
     }
-    for (uint32_t i = 0; i < endWindow_; ++i) {
-        topkIndices[idx++] = numScores - endWindow_ + i;
-    }
-
-    std::sort(topkIndices, topkIndices + k);
 }
 
 float TopkBlockSelector::ComputeBlockScore(float* qMean, const float* blockBase,
diff --git a/ucm/sparse/gsa/prefetch/include/kvcache_pre.h b/ucm/sparse/gsa/prefetch/include/kvcache_pre.h
index 262316cd..703cf0ec 100644
--- a/ucm/sparse/gsa/prefetch/include/kvcache_pre.h
+++ b/ucm/sparse/gsa/prefetch/include/kvcache_pre.h
@@ -105,6 +105,7 @@ namespace ucmprefetch
         std::map> mAllBlcoksHash;
         uint32_t mKVSzieBytes = 0;
         uint32_t mExtraTopkLen = 16;
+        bool mIsPythonLoad = false;
     public:
         std::mutex mMutex;
         bool mStopPrefetch = false;
@@ -138,7 +139,8 @@ namespace ucmprefetch
             bool isLog,
             int tpSize,
             int rank,
-            int extraTopkLen
+            int extraTopkLen,
+            bool isPythonLoad
         );
 
         void SetBlocksMap(std::string reqID, std::vector &blockTableList,
diff --git a/ucm/sparse/gsa/prefetch/prefetch_engine.py b/ucm/sparse/gsa/prefetch/prefetch_engine.py
index d31534d9..4da572ae 100644
--- a/ucm/sparse/gsa/prefetch/prefetch_engine.py
+++ b/ucm/sparse/gsa/prefetch/prefetch_engine.py
@@ -28,6 +28,7 @@ def __init__(
         is_cpu_topk: bool = False,
         is_max_norm: bool = False,
         max_norm_num: int = 1,
+        is_prefetch_done: bool = False,
         is_prefetch: Optional[bool] = True,
         head_num: Optional[int] = None,
         is_mutli_head: Optional[bool] = None,
@@ -95,6 +96,7 @@ def __init__(
             self.tp_size,
             self.rank,
             gsa_config.num_prefetch_blocks,
+            is_prefetch_done
         )
         self.topk_space = 0
 
@@ -157,9 +159,7 @@ def model_input_deal(
         block_table_tmp = self.use_block_table[:, block_table_index, :].to(
             self.device_config.device
         )
-        gen_len_tmp = self.gsa_seq_len[:, self.select_bs_index].to(
-            self.device_config.device
-        )
+        gen_len_tmp = self.gsa_seq_len[:, self.select_bs_index]
 
         list_topk_buf = list(topk_buf_tmp.unbind(dim=0))
         list_block_table = list(block_table_tmp.unbind(dim=0))
@@ -197,12 +197,14 @@ def _topk_tmp_deal(self, gsa_metadata, topk_buf_tmp):
         )
         self.topk_buf_tmp = topk_buf_tmp
 
-    def deal_async_prefetch(self, gsa_metadata, kvcache, store_ptr) -> None:
+    def deal_async_prefetch(self, is_prefetch_done, gsa_metadata, kvcache, store_ptr):
+        self.topk_space += 1
+        all_free_block_ids = None
+        all_miss_ids = None
         if not self.atb_gsa_enable:
-            return
-
+            return all_free_block_ids, all_miss_ids
         if (
-            self.prefetch_engine_c.get_prefetch_status()
+            is_prefetch_done
             and self.ptopk_prefetch_enable
             and self.is_topk_update
         ):
@@ -239,8 +241,10 @@ def deal_async_prefetch(self, gsa_metadata, kvcache, store_ptr) -> None:
                 req_id_list, topk_len_list, self.select_bs_index, kvcache, store_ptr
             )
             self.is_topk_update = False
-        else:
-            self.topk_space += 1
+        if is_prefetch_done:
+            all_free_block_ids = self.prefetch_engine_c.obtain_load_blocks()
+            all_miss_ids = self.prefetch_engine_c.obtain_miss_idxs()
+        return all_free_block_ids, all_miss_ids
 
     def del_finish_meta(self, del_req, flag: bool = True) -> None:
         if del_req in self.block_map_flag:
diff --git a/ucm/sparse/gsa/prefetch/src/kvcache_pre.cpp b/ucm/sparse/gsa/prefetch/src/kvcache_pre.cpp
index b1e0be12..15b172f5 100644
--- a/ucm/sparse/gsa/prefetch/src/kvcache_pre.cpp
+++ b/ucm/sparse/gsa/prefetch/src/kvcache_pre.cpp
@@ -106,7 +106,8 @@ namespace ucmprefetch
             bool isLog,
             int tpSize,
             int rank,
-            int extraTopkLen
+            int extraTopkLen,
+            bool isPythonLoad
         )
         :mLogger("./log/kvcache_pre_log.txt", LogLevel::INFO, isLog)
     {
@@ -128,6 +129,7 @@ namespace ucmprefetch
         mBlockSize = kvShape[0];
         mTPSize = tpSize;
         mRank = rank;
+        mIsPythonLoad = isPythonLoad;
         if(mRank != 0) {
             mLogger.SetLevel(LogLevel::WARNING);
             mIsLog = false;
@@ -322,6 +324,9 @@ namespace ucmprefetch
                 int blockID = mDocsTables[reqID][layerID][item];
                 hitBlocks.insert(blockID);
                 hitBlocksIdx.insert(std::make_pair(item, blockID));
+                if (hitBlocks.size() == (topkLen - mExtraTopkLen)) {
+                    break;
+                }
             } else {
                 missIdxs.push_back(item);
             }
@@ -329,7 +334,7 @@ namespace ucmprefetch
         oss << "------\n";
         mLogger.log(LogLevel::DEBUG, oss.str().c_str());
         oss.str("");
-        if ((hitBlocks.size() + missIdxs.size()) != (uint32_t)topkLen) {
+        if ((hitBlocks.size() + missIdxs.size()) != (uint32_t)topkLen && hitBlocks.size() != (topkLen - mExtraTopkLen)) {
             mLogger.log(LogLevel::ERROR,
                 "|KVCache Prefetch| Decode step: %u, Rank: %d, reqID: %s, layer: %d, hit size: %lu, miss size: %lu , topkLen: %d, not equal error\n",
                 mDecodeStep, mRank, reqID, layerID, hitBlocks.size(), missIdxs.size(), topkLen);
@@ -367,6 +372,8 @@ namespace ucmprefetch
                 oneFreeBlockIndex += 1;
             }
         }
+        uint32_t loadLen = oneFreeBlockTable.size();
+        missIdxs.erase(missIdxs.begin() + loadLen, missIdxs.end());
         allNeedLoadBlock[reqID][layerID] = oneFreeBlockTable;
         allMissIdxs[reqID][layerID] = missIdxs;
         LoadKVToHBM(oneFreeBlockTable, missIdxs, layerID, reqID);
@@ -389,7 +396,7 @@ namespace ucmprefetch
             oneBsInfo.bsIndex = bsIndex;
             oneBsInfo.layerID = i;
             GetHitAndMissBlock(oneBsInfo, hitBlocks, hitBlocksIdx, missIdxs);
-            if (missIdxs.size() != 0 || hitBlocksIdx.size() < (topkLen - mExtraTopkLen)) {
+            if (missIdxs.size() != 0 && hitBlocksIdx.size() < (topkLen - mExtraTopkLen)) {
                 RunPrefetchH2D(oneBsInfo, hitBlocks, hitBlocksIdx, missIdxs);
             }
             int successIndex = 0;
@@ -418,38 +425,40 @@ namespace ucmprefetch
         std::vector missIdxs, int layerID, std::string reqID)
     {
         for (size_t i = 0; i < loadNPUBlockIDs.size(); i++) {
-            if (mDelSeqIds.find(reqID) != mDelSeqIds.end()) {
-                mLogger.log(LogLevel::INFO,
-                    "Decode step: %u, Rank: %d, reqID: %s, layer: %d, stop prefetch\n",
-                    mDecodeStep, mRank, reqID.c_str(), layerID);
-                return;
-            }
-            while (mStopPrefetch) {
-                std::this_thread::sleep_for(std::chrono::microseconds(2));
-            }
-            UC::Task task{UC::Task::Type::LOAD, UC::Task::Location::DEVICE, "NFS::S2D"};
-            std::string blockId = mAllBlcoksHash[reqID][missIdxs[i]];
-            size_t kOffset = GetOffset(layerID, false);
-            size_t vOffset = GetOffset(layerID, true);
-            if (!mUseMla) {
-                task.Append(blockId, kOffset,
-                    reinterpret_cast(mKvCaches[layerID][0][loadNPUBlockIDs[i]].data_ptr()),
-                    mKVSzieBytes);
-                task.Append(blockId, vOffset,
-                    reinterpret_cast(mKvCaches[layerID][1][loadNPUBlockIDs[i]].data_ptr()),
-                    mKVSzieBytes);
-            } else {
-                task.Append(blockId, kOffset,
-                    reinterpret_cast(mKvCaches[layerID][loadNPUBlockIDs[i]].data_ptr()),
-                    mKVSzieBytes);
-            }
-            size_t taskID = mStore->Submit(std::move(task));
-            auto ret = mStore->Wait(taskID);
-            if (ret != 0) {
-                mLogger.log(LogLevel::ERROR,
-                    "Decode step: %u, Rank: %d, reqID: %s, layer: %d, blockID: %lu, miss idx: %u, load blockid: %u load k error\n",
-                    mDecodeStep, mRank, reqID.c_str(), layerID, blockId, missIdxs[i], loadNPUBlockIDs[i]);
-                return;
+            if (mIsPythonLoad) {
+                if (mDelSeqIds.find(reqID) != mDelSeqIds.end()) {
+                    mLogger.log(LogLevel::INFO,
+                        "Decode step: %u, Rank: %d, reqID: %s, layer: %d, stop prefetch\n",
+                        mDecodeStep, mRank, reqID.c_str(), layerID);
+                    return;
+                }
+                while (mStopPrefetch) {
+                    std::this_thread::sleep_for(std::chrono::microseconds(2));
+                }
+                UC::Task task{UC::Task::Type::LOAD, UC::Task::Location::DEVICE, "NFS::S2D"};
+                std::string blockId = mAllBlcoksHash[reqID][missIdxs[i]];
+                size_t kOffset = GetOffset(layerID, false);
+                size_t vOffset = GetOffset(layerID, true);
+                if (!mUseMla) {
+                    task.Append(blockId, kOffset,
+                        reinterpret_cast(mKvCaches[layerID][0][loadNPUBlockIDs[i]].data_ptr()),
+                        mKVSzieBytes);
+                    task.Append(blockId, vOffset,
+                        reinterpret_cast(mKvCaches[layerID][1][loadNPUBlockIDs[i]].data_ptr()),
+                        mKVSzieBytes);
+                } else {
+                    task.Append(blockId, kOffset,
+                        reinterpret_cast(mKvCaches[layerID][loadNPUBlockIDs[i]].data_ptr()),
+                        mKVSzieBytes);
+                }
+                size_t taskID = mStore->Submit(std::move(task));
+                auto ret = mStore->Wait(taskID);
+                if (ret != 0) {
+                    mLogger.log(LogLevel::ERROR,
+                        "Decode step: %u, Rank: %d, reqID: %s, layer: %d, blockID: %lu, miss idx: %u, load blockid: %u load k error\n",
+                        mDecodeStep, mRank, reqID.c_str(), layerID, blockId, missIdxs[i], loadNPUBlockIDs[i]);
+                    return;
+                }
             }
 
             int oriIdx = mBlocksMap[reqID][layerID][loadNPUBlockIDs[i]];
@@ -472,10 +481,16 @@ namespace ucmprefetch
         } else {
             mKVSzieBytes = kvCaches[0].element_size() * kvCaches[0][0][0].numel();
         }
+        if (storePtr == nullptr) {
+            mLogger.log(LogLevel::ERROR,
+                "Decode step: %u, |KVCache Prefetch| storePtr is nullptr error\n",
+                mDecodeStep);
+            std::abort();
+        }
         mStore = static_cast *>(storePtr);
         mLogger.log(LogLevel::INFO,
-            "Decode step: %u, |KVCache Prefetch| start mKVSzieBytes: %u, mTensorElemSize %u\n",
-            mDecodeStep, mKVSzieBytes, mTensorElemSize);
+            "Decode step: %u, |KVCache Prefetch| start mKVSzieBytes: %u, mTensorElemSize %u, store %p\n",
+            mDecodeStep, mKVSzieBytes, mTensorElemSize, mStore);
         }
         mKvCaches = kvCaches;
         mLogger.log(LogLevel::INFO,
@@ -495,7 +510,11 @@ namespace ucmprefetch
         mMutex.lock();
         mIsPrefetchDone = false;
         mMutex.unlock();
-        mThreadPool->enqueue(MutliBSThreadFun, this);
+        if (mIsPythonLoad) {
+            MutliBSThreadFun(this);
+        } else {
+            mThreadPool->enqueue(MutliBSThreadFun, this);
+        }
     }
 
     void GSAPrefetchEngineC::SetBlockTableInfo(torch::Tensor &blockTables, torch::Tensor &blockLengths,
diff --git a/ucm/sparse/gsa/prefetch/src/pybinds.cpp b/ucm/sparse/gsa/prefetch/src/pybinds.cpp
index 5c9391a9..decd3895 100644
--- a/ucm/sparse/gsa/prefetch/src/pybinds.cpp
+++ b/ucm/sparse/gsa/prefetch/src/pybinds.cpp
@@ -12,7 +12,7 @@ namespace ucmprefetch{
     pybind11::class_(m, "GSAPrefetchEngineC")
         .def(pybind11::init &,
-            bool, bool, int, int, int>())
+            bool, bool, int, int, int, bool>())
         .def("set_blocks_map", &ucmprefetch::GSAPrefetchEngineC::SetBlocksMap)
         .def("set_blocks_map_multilayer", &ucmprefetch::GSAPrefetchEngineC::SetBlocksMapMultiLayer)
        .def("add_blocks_map", &ucmprefetch::GSAPrefetchEngineC::AddBlocksMap)

From 93ba6bd857ada938454e80fc00389f8400db35f1 Mon Sep 17 00:00:00 2001
From: zbb200819 <1130072360@qq.com>
Date: Mon, 1 Dec 2025 16:33:30 +0800
Subject: [PATCH 2/8] cleancode

---
 ucm/sparse/gsa/gsa.py                      | 58 +++++++++++-----------
 ucm/sparse/gsa/prefetch/prefetch_engine.py |  8 +--
 2 files changed, 30 insertions(+), 36 deletions(-)

diff --git a/ucm/sparse/gsa/gsa.py b/ucm/sparse/gsa/gsa.py
index 29e5c0b1..7918d439 100644
--- a/ucm/sparse/gsa/gsa.py
+++ b/ucm/sparse/gsa/gsa.py
@@ -124,9 +124,7 @@ def set_block_hashes(self, token_ids):
         if self.rank != 0 and not self.is_mla:
             self.newqrequest_hasher = RequestHasher(self._vllm_config, self.rank)
             for i, ucm_block_id in enumerate(self.block_hashes):
-                self.block_hashes[i] = str(
-                    self.newqrequest_hasher(ucm_block_id)
-                )
+                self.block_hashes[i] = str(self.newqrequest_hasher(ucm_block_id))
 
     def add_req_new(
         self, num_scheduled_tokens, add_req_state, index_in_batch, offset
@@ -295,7 +293,9 @@ def get_model_input(
         for new_req in scheduler_output.scheduled_new_reqs:
             if new_req.req_id in self.gsa_stats:
                 del self.gsa_stats[new_req.req_id]
-            self.gsa_stats[new_req.req_id] = GSAReqStat(new_req.req_id, self.block_size, self._vllm_config)
+            self.gsa_stats[new_req.req_id] = GSAReqStat(
+                new_req.req_id, self.block_size, self._vllm_config
+            )
             self.gsa_stats[new_req.req_id].add_req_new(
                 scheduler_output.num_scheduled_tokens[new_req.req_id],
                 requests[new_req.req_id],
@@ -417,12 +417,8 @@ def cal_topk(self, intermediate_q, current_layer_id):
             qk.reshape(bs, self.att_num_heads, blk_num, max_norm_num), dim=-1
         )
         dot_product_weights = attention_weights_without_norm.mean(1)
-        dot_product_weights.masked_fill_(
-            self.include_mask == 1, float("inf")
-        )
-        dot_product_weights.masked_fill_(
-            self.exclude_mask == 1, float("-inf")
-        )
+        dot_product_weights.masked_fill_(self.include_mask == 1, float("inf"))
+        dot_product_weights.masked_fill_(self.exclude_mask == 1, float("-inf"))
         selected_block_nums = self.topk_len_list[0]
         _, top_indices = torch.topk(
             dot_product_weights, selected_block_nums, dim=-1, sorted=False
@@ -549,7 +545,11 @@ def init_topk_cal(
         )
         if CUDA_TOPK:
             self.gsa_cuda_topk = TopkCal(
-                att_num_heads, kv_num_heads, head_size, prefetch_engine.kpre_caches, self.use_mla
+                att_num_heads,
+                kv_num_heads,
+                head_size,
+                prefetch_engine.kpre_caches,
+                self.use_mla
             )
 
     def copy_q(self, query: torch.Tensor, current_layer_id: int) -> None:
@@ -910,19 +910,26 @@ def execute_finished(self):
             if self.is_python_load:
                 is_prefetch_done = self.check_transfer_task_done()
             else:
-                is_prefetch_done = self.prefetch_engine.prefetch_engine_c.get_prefetch_status()
+                is_prefetch_done = (
+                    self.prefetch_engine.prefetch_engine_c.get_prefetch_status()
+                )
             all_free_block_ids, all_miss_ids = self.prefetch_engine.deal_async_prefetch(
-                is_prefetch_done, self.gsa_metadata, kv_caches, self.connector.cc_store()
+                is_prefetch_done,
+                self.gsa_metadata,
+                kv_caches,
+                self.connector.cc_store()
             )
             if is_prefetch_done:
-                self.load_transfer_task(all_free_block_ids, all_miss_ids, kv_caches)
+                self.launch_transfer_task(all_free_block_ids, all_miss_ids, kv_caches)
         else:
-            self.prefetch_engine.deal_async_prefetch(False, self.gsa_metadata, kv_caches, None)
+            self.prefetch_engine.deal_async_prefetch(
+                False, self.gsa_metadata, kv_caches, None
+            )
 
     def launch_transfer_task(self, all_free_block_ids, all_miss_ids, kv_caches):
         if all_free_block_ids == None:
             return
-        fn = getattr(self.connector, 'load')
+        fn = getattr(self.connector, "load")
         precision = self.element_size
         if self.use_mla:
             block_data_size = kv_caches[0].numel() * precision
@@ -948,8 +955,7 @@ def launch_transfer_task(self, all_free_block_ids, all_miss_ids, kv_caches):
                 )
                 offsets_k += [offset_k] * length
                 block_hashes += [
-                    req_block_hash[i]
-                    for i in all_miss_ids[req_id][layer_id]
+                    req_block_hash[i] for i in all_miss_ids[req_id][layer_id]
                 ]
 
                 if not self.use_mla:
@@ -965,8 +971,7 @@ def launch_transfer_task(self, all_free_block_ids, all_miss_ids, kv_caches):
                     )
                     offsets_k += [offset_v] * length
                     block_hashes += [
-                        req_block_hash[i]
-                        for i in all_miss_ids[req_id][layer_id]
+                        req_block_hash[i] for i in all_miss_ids[req_id][layer_id]
                     ]
                     key_src_tensors += [
                         kv_caches[layer_id][1][_id]
@@ -1065,16 +1070,9 @@ def _start_topk_cal(self) -> None:
                     )
                     topk_len_list.append(one_topk_len)
                     if CUDA_TOPK:
-                        include_masks.append(
-                            req_meta.include_mask
-                        )
-                        exclude_masks.append(
-                            req_meta.exclude_mask
-                        )
-                        repre_slot_mappings.append(
-                            req_meta.repre_slot_mapping
-                        )
-
+                        include_masks.append(req_meta.include_mask)
+                        exclude_masks.append(req_meta.exclude_mask)
+                        repre_slot_mappings.append(req_meta.repre_slot_mapping)
                 else:
                     is_decode.append(False)
                 repre_slot_mappings_all.append(req_meta.repre_slot_mapping)
diff --git a/ucm/sparse/gsa/prefetch/prefetch_engine.py b/ucm/sparse/gsa/prefetch/prefetch_engine.py
index 4da572ae..9b755014 100644
--- a/ucm/sparse/gsa/prefetch/prefetch_engine.py
+++ b/ucm/sparse/gsa/prefetch/prefetch_engine.py
@@ -96,7 +96,7 @@ def __init__(
             self.tp_size,
             self.rank,
             gsa_config.num_prefetch_blocks,
-            is_prefetch_done
+            is_prefetch_done,
         )
         self.topk_space = 0
 
@@ -203,11 +203,7 @@ def deal_async_prefetch(self, is_prefetch_done, gsa_metadata, kvcache, store_ptr
         all_miss_ids = None
         if not self.atb_gsa_enable:
             return all_free_block_ids, all_miss_ids
-        if (
-            is_prefetch_done
-            and self.ptopk_prefetch_enable
-            and self.is_topk_update
-        ):
+        if is_prefetch_done and self.ptopk_prefetch_enable and self.is_topk_update:
             tmp = self.use_block_table
             self.use_block_table = self.m_load_success_list
             self.m_load_success_list = tmp

From 1248ad85472f0d50b5251605d52f7a33e935da31 Mon Sep 17 00:00:00 2001
From: zbb200819 <1130072360@qq.com>
Date: Mon, 1 Dec 2025 16:39:30 +0800
Subject: [PATCH 3/8] cleancode

---
 ucm/sparse/gsa/gsa.py | 23 +++++++++++++++--------
 1 file changed, 15 insertions(+), 8 deletions(-)

diff --git a/ucm/sparse/gsa/gsa.py b/ucm/sparse/gsa/gsa.py
index 7918d439..64eeeba8 100644
--- a/ucm/sparse/gsa/gsa.py
+++ b/ucm/sparse/gsa/gsa.py
@@ -97,7 +97,7 @@ def is_last_chunk(self) -> bool:
 
     def get_seq_len(self) -> int:
         return self.num_computed_tokens + self.num_scheduled_tokens
-
+
     def set_block_hashes(self, token_ids):
         if self.block_hashes is not None:
             return
@@ -277,7 +277,9 @@ def get_model_input(
             if scheduler_output.scheduled_cached_reqs.resumed_from_preemption[index]:
                 del self.gsa_stats[req_id]
                 prefetch_engine.del_finish_meta(req_id, False)
-            self.gsa_stats[req_id] = GSAReqStat(req_id, self.block_size, self._vllm_config)
+            self.gsa_stats[req_id] = GSAReqStat(
+                req_id, self.block_size, self._vllm_config
+            )
             self.gsa_stats[req_id].add_req_new(
                 scheduler_output.num_scheduled_tokens[req_id],
                 requests[req_id],
@@ -425,6 +427,7 @@ def cal_topk(self, intermediate_q, current_layer_id):
         )
         self.topk_caches[current_layer_id][self.cal_topk_id] = top_indices
 
+
 @cache
 def get_offset(block_shape, rank, tp_size, precision, layer_id, is_v, is_mla) -> int:
     block_size, num_key_heads_per_tp, head_size = block_shape
@@ -440,6 +443,7 @@ def get_offset(block_shape, rank, tp_size, precision, layer_id, is_v, is_mla) -> int:
     v_offset = k_offset + k_min_data_block_size
     return v_offset if is_v else k_offset
 
+
 @cache
 def compute_parent_block_hash(model_name, world_size, dtype, seed_rank=0) -> int:
     meta = f"{model_name}:{world_size}:{dtype}:{seed_rank}"
@@ -447,6 +451,7 @@ def compute_parent_block_hash(model_name, world_size, dtype, seed_rank=0) -> int
     h_seed = hashlib.md5(meta_bytes + b"UCM_HASH_SEED").digest()
     return int.from_bytes(h_seed, byteorder="big")
 
+
 @cache
 def compute_layer_offset(
     block_data_size: int,
@@ -464,9 +469,11 @@ def compute_layer_offset(
     v_offset = k_offset + block_data_size
     return v_offset if is_v else k_offset
 
+
 def task_hash_func(block_ids, store_type, tensor_type):
     return hash((tuple(block_ids), store_type, tensor_type))
 
+
 class GSA(UcmSparseBase):
     def __init__(self, vllm_config: VllmConfig, role: UcmSparseRole):
         super().__init__(vllm_config, role)
@@ -549,7 +556,7 @@ def init_topk_cal(
                 kv_num_heads,
                 head_size,
                 prefetch_engine.kpre_caches,
-                self.use_mla
+                self.use_mla,
             )
 
     def copy_q(self, query: torch.Tensor, current_layer_id: int) -> None:
@@ -917,7 +924,7 @@ def execute_finished(self):
                 is_prefetch_done,
                 self.gsa_metadata,
                 kv_caches,
-                self.connector.cc_store()
+                self.connector.cc_store(),
             )
             if is_prefetch_done:
                 self.launch_transfer_task(all_free_block_ids, all_miss_ids, kv_caches)
@@ -935,7 +942,7 @@ def launch_transfer_task(self, all_free_block_ids, all_miss_ids, kv_caches):
             block_data_size = kv_caches[0].numel() * precision
         else:
             block_data_size = kv_caches[0][0].numel() * precision
-
+
         offsets_k = []
         key_src_tensors = []
         block_hashes = []
@@ -957,7 +964,7 @@ def launch_transfer_task(self, all_free_block_ids, all_miss_ids, kv_caches):
                 block_hashes += [
                     req_block_hash[i] for i in all_miss_ids[req_id][layer_id]
                 ]
-
+
                 if not self.use_mla:
                     key_src_tensors += [
                         kv_caches[layer_id][0][_id]
@@ -986,11 +993,11 @@ def launch_transfer_task(self, all_free_block_ids, all_miss_ids, kv_caches):
         task_all = fn(block_hashes, offsets_k, key_src_tensors)
         task_all_hash = task_hash_func(block_hashes, "load", "value")
         self.task_load[task_all_hash] = task_all
-
+
     def check_transfer_task_done(self) -> bool:
         if len(self.task_load) == 0:
             return True
-
+
         for task_hash, task in self.task_load.items():
             ret = self.connector.check(task)
             if not ret:

From c38033fd3a71c821c191e5ad63f11a575d2e8c3b Mon Sep 17 00:00:00 2001
From: zbb200819 <1130072360@qq.com>
Date: Mon, 1 Dec 2025 19:55:23 +0800
Subject: [PATCH 4/8] md prefetch bug

---
 ucm/sparse/gsa/gsa.py                         | 12 +++++-----
 ucm/sparse/gsa/prefetch/include/kvcache_pre.h |  2 ++
 ucm/sparse/gsa/prefetch/prefetch_engine.py    | 14 +++++++----
 ucm/sparse/gsa/prefetch/src/kvcache_pre.cpp   | 23 +++++++++++++++++--
 4 files changed, 39 insertions(+), 12 deletions(-)

diff --git a/ucm/sparse/gsa/gsa.py b/ucm/sparse/gsa/gsa.py
index 64eeeba8..a4b0fc3b 100644
--- a/ucm/sparse/gsa/gsa.py
+++ b/ucm/sparse/gsa/gsa.py
@@ -121,7 +121,7 @@ def set_block_hashes(self, token_ids):
             )
             parent_block_hash_value = hash_value
 
-        if self.rank != 0 and not self.is_mla:
+        if self.rank != 0 and not self.use_mla:
             self.newqrequest_hasher = RequestHasher(self._vllm_config, self.rank)
             for i, ucm_block_id in enumerate(self.block_hashes):
                 self.block_hashes[i] = str(self.newqrequest_hasher(ucm_block_id))
@@ -278,7 +278,7 @@ def get_model_input(
                 del self.gsa_stats[req_id]
                 prefetch_engine.del_finish_meta(req_id, False)
             self.gsa_stats[req_id] = GSAReqStat(
-                req_id, self.block_size, self._vllm_config
+                req_id, self._vllm_config
             )
             self.gsa_stats[req_id].add_req_new(
                 scheduler_output.num_scheduled_tokens[req_id],
@@ -296,7 +296,7 @@ def get_model_input(
             if new_req.req_id in self.gsa_stats:
                 del self.gsa_stats[new_req.req_id]
             self.gsa_stats[new_req.req_id] = GSAReqStat(
-                new_req.req_id, self.block_size, self._vllm_config
+                new_req.req_id, self._vllm_config
             )
             self.gsa_stats[new_req.req_id].add_req_new(
                 scheduler_output.num_scheduled_tokens[new_req.req_id],
@@ -500,10 +500,10 @@ def __init__(self, vllm_config: VllmConfig, role: UcmSparseRole):
         self.dtype = vllm_config.model_config.dtype
         if PTOPK_PREFETCH_ENABLE:
             if role == UcmSparseRole.WORKER:
-                self.connector = get_kv_transfer_group().connector
+                self.connector = get_kv_transfer_group().connector.store
             else:
                 self.connector = None
-        self.is_python_load = torch.npu.is_available()
+        self.is_python_load = not torch.cuda.is_available()
         if CUDA_TOPK:
             self.prefetch_engine = GSAPrefetchBase(
                 vllm_config, 16, True, False, False, 1, self.is_python_load
@@ -926,7 +926,7 @@ def execute_finished(self):
                 kv_caches,
                 self.connector.cc_store(),
             )
-            if is_prefetch_done:
+            if self.is_python_load:
                 self.launch_transfer_task(all_free_block_ids, all_miss_ids, kv_caches)
         else:
             self.prefetch_engine.deal_async_prefetch(
diff --git a/ucm/sparse/gsa/prefetch/include/kvcache_pre.h b/ucm/sparse/gsa/prefetch/include/kvcache_pre.h
index 703cf0ec..d0389759 100644
--- a/ucm/sparse/gsa/prefetch/include/kvcache_pre.h
+++ b/ucm/sparse/gsa/prefetch/include/kvcache_pre.h
@@ -183,6 +183,8 @@ namespace ucmprefetch
 
         size_t GetOffset(uint32_t layerID, bool isV);
 
+        size_t GetOffsetNew(uint32_t layerID, bool isV);
+
         std::map>> ObtainLoadBlocks();
 
         std::map>> ObtainMissIdxs();
diff --git a/ucm/sparse/gsa/prefetch/prefetch_engine.py b/ucm/sparse/gsa/prefetch/prefetch_engine.py
index 9b755014..c3832460 100644
--- a/ucm/sparse/gsa/prefetch/prefetch_engine.py
+++ b/ucm/sparse/gsa/prefetch/prefetch_engine.py
@@ -28,7 +28,7 @@ def __init__(
         is_cpu_topk: bool = False,
         is_max_norm: bool = False,
         max_norm_num: int = 1,
-        is_prefetch_done: bool = False,
+        is_python_load: bool = False,
         is_prefetch: Optional[bool] = True,
         head_num: Optional[int] = None,
         is_mutli_head: Optional[bool] = None,
@@ -85,6 +85,7 @@ def __init__(
         )
         self._init_tensor()
         kv_shape = [self.block_size, self.num_kv_heads, self.head_size]
+        self.is_python_load = is_python_load
         self.prefetch_engine_c = gsa_prefetch.GSAPrefetchEngineC(
             self.prefetch_blocks,
             self.m_load_success_list,
@@ -96,7 +97,7 @@ def __init__(
             self.tp_size,
             self.rank,
             gsa_config.num_prefetch_blocks,
-            is_prefetch_done,
+            self.is_python_load,
         )
         self.topk_space = 0
 
@@ -159,7 +160,12 @@ def model_input_deal(
         block_table_tmp = self.use_block_table[:, block_table_index, :].to(
             self.device_config.device
         )
-        gen_len_tmp = self.gsa_seq_len[:, self.select_bs_index]
+        if torch.cuda.is_available():
+            gen_len_tmp = self.gsa_seq_len[:, self.select_bs_index].to(
+                self.device_config.device
+            )
+        else:
+            gen_len_tmp = self.gsa_seq_len[:, self.select_bs_index]
 
         list_topk_buf = list(topk_buf_tmp.unbind(dim=0))
         list_block_table = list(block_table_tmp.unbind(dim=0))
@@ -237,7 +243,7 @@ def deal_async_prefetch(self, is_prefetch_done, gsa_metadata, kvcache, store_ptr
                 req_id_list, topk_len_list, self.select_bs_index, kvcache, store_ptr
             )
             self.is_topk_update = False
-        if is_prefetch_done:
+        if self.is_python_load:
             all_free_block_ids = self.prefetch_engine_c.obtain_load_blocks()
             all_miss_ids = self.prefetch_engine_c.obtain_miss_idxs()
         return all_free_block_ids, all_miss_ids
diff --git a/ucm/sparse/gsa/prefetch/src/kvcache_pre.cpp b/ucm/sparse/gsa/prefetch/src/kvcache_pre.cpp
index 15b172f5..c6aea95e 100644
--- a/ucm/sparse/gsa/prefetch/src/kvcache_pre.cpp
+++ b/ucm/sparse/gsa/prefetch/src/kvcache_pre.cpp
@@ -163,6 +163,25 @@ namespace ucmprefetch
         }
     }
 
+    size_t GSAPrefetchEngineC::GetOffsetNew(uint32_t layerID, bool isV)
+    {
+        size_t kMinDataBlockSize = static_cast(mBlockSize) * mHeadNum * mHeadSzie * mTensorElemSize;
+        size_t layerSize = kMinDataBlockSize * 2;
+        size_t kOffset = layerSize * layerID;
+        if (mUseMla) {
+            layerSize = kMinDataBlockSize;
+            kOffset = layerSize * layerID;
+            return kOffset;
+        }
+        size_t vOffset = kOffset + kMinDataBlockSize;
+
+        if (isV) {
+            return vOffset;
+        } else {
+            return kOffset;
+        }
+    }
+
     void GSAPrefetchEngineC::CheckInputIndex(uint32_t maxLen, uint32_t index)
     {
         if (index >= maxLen) {
@@ -437,8 +456,8 @@ namespace ucmprefetch
                 }
                 UC::Task task{UC::Task::Type::LOAD, UC::Task::Location::DEVICE, "NFS::S2D"};
                 std::string blockId = mAllBlcoksHash[reqID][missIdxs[i]];
-                size_t kOffset = GetOffset(layerID, false);
-                size_t vOffset = GetOffset(layerID, true);
+                size_t kOffset = GetOffsetNew(layerID, false);
+                size_t vOffset = GetOffsetNew(layerID, true);
                 if (!mUseMla) {
                     task.Append(blockId, kOffset,
                         reinterpret_cast(mKvCaches[layerID][0][loadNPUBlockIDs[i]].data_ptr()),

From 8d15a94275adce799dab08e2a390a33a64f5d716 Mon Sep 17 00:00:00 2001
From: zbb200819 <1130072360@qq.com>
Date: Mon, 1 Dec 2025 20:47:44 +0800
Subject: [PATCH 5/8] md bug

---
 ucm/sparse/gsa/prefetch/src/kvcache_pre.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ucm/sparse/gsa/prefetch/src/kvcache_pre.cpp b/ucm/sparse/gsa/prefetch/src/kvcache_pre.cpp
index c6aea95e..a6f61b88 100644
--- a/ucm/sparse/gsa/prefetch/src/kvcache_pre.cpp
+++ b/ucm/sparse/gsa/prefetch/src/kvcache_pre.cpp
@@ -444,7 +444,7 @@ namespace ucmprefetch
         std::vector missIdxs, int layerID, std::string reqID)
     {
         for (size_t i = 0; i < loadNPUBlockIDs.size(); i++) {
-            if (mIsPythonLoad) {
+            if (!mIsPythonLoad) {
                 if (mDelSeqIds.find(reqID) != mDelSeqIds.end()) {
                     mLogger.log(LogLevel::INFO,
                         "Decode step: %u, Rank: %d, reqID: %s, layer: %d, stop prefetch\n",

From 642a286911fdf4adb53231f82a027e3cdd04a9a2 Mon Sep 17 00:00:00 2001
From: zbb200819 <1130072360@qq.com>
Date: Mon, 1 Dec 2025 20:57:28 +0800
Subject: [PATCH 6/8] cleancode

---
 ucm/sparse/gsa/gsa.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/ucm/sparse/gsa/gsa.py b/ucm/sparse/gsa/gsa.py
index a4b0fc3b..b1bf1e5c 100644
--- a/ucm/sparse/gsa/gsa.py
+++ b/ucm/sparse/gsa/gsa.py
@@ -277,9 +277,7 @@ def get_model_input(
             if scheduler_output.scheduled_cached_reqs.resumed_from_preemption[index]:
                 del self.gsa_stats[req_id]
                 prefetch_engine.del_finish_meta(req_id, False)
-            self.gsa_stats[req_id] = GSAReqStat(
-                req_id, self._vllm_config
-            )
+            self.gsa_stats[req_id] = GSAReqStat(req_id, self._vllm_config)
             self.gsa_stats[req_id].add_req_new(
                 scheduler_output.num_scheduled_tokens[req_id],
                 requests[req_id],

From 0585ff66aac4157b5172f7def06942bfcecfe052 Mon Sep 17 00:00:00 2001
From: zbb200819 <1130072360@qq.com>
Date: Mon, 1 Dec 2025 05:57:38 -0800
Subject: [PATCH 7/8] cleancode

---
 .../gsa/offload_ops/src/select_topk_block.cpp |   87 +-
 ucm/sparse/gsa/prefetch/include/kvcache_log.h |   88 +-
 ucm/sparse/gsa/prefetch/include/kvcache_pre.h |  342 +++---
 ucm/sparse/gsa/prefetch/src/kvcache_pre.cpp   | 1018 ++++++++---------
 ucm/sparse/gsa/prefetch/src/pybinds.cpp       |   45 +-
 5 files changed, 751 insertions(+), 829 deletions(-)

diff --git a/ucm/sparse/gsa/offload_ops/src/select_topk_block.cpp b/ucm/sparse/gsa/offload_ops/src/select_topk_block.cpp
index fe7afa5a..65852019 100644
--- a/ucm/sparse/gsa/offload_ops/src/select_topk_block.cpp
+++ b/ucm/sparse/gsa/offload_ops/src/select_topk_block.cpp
@@ -1,37 +1,31 @@
+#include "select_topk_block.h"
 #include
-#include
+#include
 #include
 #include
-#include
-#include "select_topk_block.h"
+#include
 
 namespace SelectTopkBlock {
 
 #define OMP_THREAD_NUM 16u
 
-bool TopkBlockSelector::ValidateParameters(float* q, const float* kRepre,
-    uint32_t numBlock, uint32_t kHead, uint32_t qHead,
-    uint32_t numKrepre, uint32_t headSize)
+bool TopkBlockSelector::ValidateParameters(float* q, const float* kRepre, uint32_t numBlock,
+                                           uint32_t kHead, uint32_t qHead, uint32_t numKrepre,
+                                           uint32_t headSize)
 {
-    return (q != nullptr) && (kRepre != nullptr) &&
&& (kRepre != nullptr) && - (numBlock > 0) && (kHead > 0) && (qHead > 0) && + return (q != nullptr) && (kRepre != nullptr) && (numBlock > 0) && (kHead > 0) && (qHead > 0) && (numKrepre > 0) && (headSize > 0); } -void TopkBlockSelector::TopKImpl(const float* scores, uint32_t numScores, uint32_t k, int32_t* topkIndices) +void TopkBlockSelector::TopKImpl(const float* scores, uint32_t numScores, uint32_t k, + int32_t* topkIndices) { if (startWindow_ + endWindow_ >= numScores || k >= numScores || k == 0) { - for (uint32_t i = 0; i < numScores; ++i) { - topkIndices[i] = i; - } + for (uint32_t i = 0; i < numScores; ++i) { topkIndices[i] = i; } return; } uint32_t idx = 0; - for (uint32_t i = 0; i < startWindow_; ++i) { - topkIndices[idx++] = i; - } - for (uint32_t i = 0; i < endWindow_; ++i) { - topkIndices[idx++] = numScores - endWindow_ + i; - } + for (uint32_t i = 0; i < startWindow_; ++i) { topkIndices[idx++] = i; } + for (uint32_t i = 0; i < endWindow_; ++i) { topkIndices[idx++] = numScores - endWindow_ + i; } int32_t midCount = k - startWindow_ - endWindow_; if (midCount > 0) { std::vector middleIndices; @@ -39,19 +33,16 @@ void TopkBlockSelector::TopKImpl(const float* scores, uint32_t numScores, uint32 for (uint32_t i = startWindow_; i < numScores - endWindow_; ++i) { middleIndices.push_back(i); } - std::stable_sort(middleIndices.begin(), middleIndices.end(), - [scores](uint32_t lhs, uint32_t rhs) { - return scores[lhs] > scores[rhs]; - }); - for (int32_t i = 0; i < midCount; ++i) { - topkIndices[idx++] = middleIndices[i]; - } + std::stable_sort( + middleIndices.begin(), middleIndices.end(), + [scores](uint32_t lhs, uint32_t rhs) { return scores[lhs] > scores[rhs]; }); + for (int32_t i = 0; i < midCount; ++i) { topkIndices[idx++] = middleIndices[i]; } } } -float TopkBlockSelector::ComputeBlockScore(float* qMean, const float* blockBase, - uint32_t kHead, uint32_t numKrepre, - uint32_t headSize, const VecProductClass& vecProduct) +float TopkBlockSelector::ComputeBlockScore(float* qMean, const float* blockBase, uint32_t kHead, + uint32_t numKrepre, uint32_t headSize, + const VecProductClass& vecProduct) { const size_t headOffset = headSize; const size_t normOffset = headSize; @@ -80,8 +71,10 @@ const VecProductClass& TopkBlockSelector::ThreadLocalVecProduct::GetInstance() return instance; } -std::vector TopkBlockSelector::ComputeKQDotScores(const float* __restrict qMean, const float* __restrict kRepre, - uint32_t numBlock, uint32_t kHead, uint32_t numKrepre, uint32_t headSize) +std::vector TopkBlockSelector::ComputeKQDotScores(const float* __restrict qMean, + const float* __restrict kRepre, + uint32_t numBlock, uint32_t kHead, + uint32_t numKrepre, uint32_t headSize) { std::vector blockScores(numBlock, 0.0f); const size_t blockOffset = static_cast(kHead * numKrepre * headSize); @@ -92,16 +85,16 @@ std::vector TopkBlockSelector::ComputeKQDotScores(const float* __restrict if (idxBlock + 1 < numBlock) { __builtin_prefetch(kRepre + (idxBlock + 1) * blockOffset, 0, 1); } - blockScores[idxBlock] = ComputeBlockScore(const_cast(qMean), blockBase, kHead, numKrepre, headSize, vecProduct); + blockScores[idxBlock] = ComputeBlockScore(const_cast(qMean), blockBase, kHead, + numKrepre, headSize, vecProduct); } return blockScores; } -void TopkBlockSelector::ComputeQHeadMean(float* __restrict q, uint32_t kHead, uint32_t qHead, uint32_t headSize) +void TopkBlockSelector::ComputeQHeadMean(float* __restrict q, uint32_t kHead, uint32_t qHead, + uint32_t headSize) { - if (kHead == qHead) { - return; - } + if (kHead 
== qHead) { return; } const VecProductClass& vecProduct = ThreadLocalVecProduct::GetInstance(); const uint32_t groupSize = qHead / kHead; for (uint32_t kIdx = 0; kIdx < kHead; ++kIdx) { @@ -112,28 +105,25 @@ void TopkBlockSelector::ComputeQHeadMean(float* __restrict q, uint32_t kHead, ui } } -void TopkBlockSelector::SelectTopK(float* q, const float* kRepre, - uint32_t numBlock, uint32_t kHead, uint32_t qHead, - uint32_t numKrepre, uint32_t headSize, +void TopkBlockSelector::SelectTopK(float* q, const float* kRepre, uint32_t numBlock, uint32_t kHead, + uint32_t qHead, uint32_t numKrepre, uint32_t headSize, uint32_t topkLength, int32_t* topkResult) { if (!ValidateParameters(q, kRepre, numBlock, kHead, qHead, numKrepre, headSize) || topkResult == nullptr || topkLength == 0) { - return; + return; } ComputeQHeadMean(q, kHead, qHead, headSize); - const std::vector scores = ComputeKQDotScores(q, kRepre, numBlock, - kHead, numKrepre, headSize); + const std::vector scores = + ComputeKQDotScores(q, kRepre, numBlock, kHead, numKrepre, headSize); TopKImpl(scores.data(), numBlock, topkLength, topkResult); } void TopkBlockSelector::SelectTopKBS(const std::vector& qCacheVec, const std::vector& kfCacheVec, - const std::vector& topkCacheVec, - uint32_t numBatch, - const std::vector& numBlockVec, - uint32_t kHead, uint32_t qHead, - uint32_t numKrepre, uint32_t headSize, + const std::vector& topkCacheVec, uint32_t numBatch, + const std::vector& numBlockVec, uint32_t kHead, + uint32_t qHead, uint32_t numKrepre, uint32_t headSize, const std::vector& topkLengthVec) { for (uint32_t bs = 0; bs < numBatch; ++bs) { @@ -142,9 +132,8 @@ void TopkBlockSelector::SelectTopKBS(const std::vector& qCacheVec, float* q = qCacheVec[bs]; const float* kRepre = kfCacheVec[bs]; int32_t* topkResult = topkCacheVec[bs]; - SelectTopK(q, kRepre, numBlock, kHead, qHead, - numKrepre, headSize, topkLength, topkResult); + SelectTopK(q, kRepre, numBlock, kHead, qHead, numKrepre, headSize, topkLength, topkResult); } } -} \ No newline at end of file +} // namespace SelectTopkBlock \ No newline at end of file diff --git a/ucm/sparse/gsa/prefetch/include/kvcache_log.h b/ucm/sparse/gsa/prefetch/include/kvcache_log.h index 38caee9d..7d446ca3 100644 --- a/ucm/sparse/gsa/prefetch/include/kvcache_log.h +++ b/ucm/sparse/gsa/prefetch/include/kvcache_log.h @@ -1,22 +1,17 @@ #ifndef ATB_KV_LOG_H #define ATB_KV_LOG_H -#include -#include -#include #include -#include -#include +#include #include +#include +#include #include -enum class LogLevel { - DEBUG, - INFO, - WARNING, - ERROR -}; +#include +#include +#include +enum class LogLevel { DEBUG, INFO, WARNING, ERROR }; -class Logger -{ +class Logger { private: std::ofstream mLogFile; LogLevel mMinLevel; @@ -25,13 +20,12 @@ class Logger static std::string LevelToString(LogLevel level) { - switch (level) - { - case LogLevel::DEBUG: return "DEBUG"; - case LogLevel::INFO: return "INFO"; - case LogLevel::WARNING: return "WARNING"; - case LogLevel::ERROR: return "ERROR"; - default: return "UNKNOWN"; + switch (level) { + case LogLevel::DEBUG: return "DEBUG"; + case LogLevel::INFO: return "INFO"; + case LogLevel::WARNING: return "WARNING"; + case LogLevel::ERROR: return "ERROR"; + default: return "UNKNOWN"; } } @@ -39,8 +33,8 @@ class Logger { auto now = std::chrono::system_clock::now(); auto nowC = std::chrono::system_clock::to_time_t(now); - auto ms = std::chrono::duration_cast( - now.time_since_epoch()) % 1000; + auto ms = + std::chrono::duration_cast(now.time_since_epoch()) % 1000; std::stringstream oss; oss 
<< std::put_time(std::localtime(&nowC), "%Y-%m-%d %H:%M:%S"); oss << '.' << std::setfill('0') << std::setw(3) << ms.count(); @@ -48,8 +42,8 @@ class Logger } public: - Logger(const std::string &fileName, LogLevel level = LogLevel::INFO, bool enable = true) - :mMinLevel(level), mEnable(enable) + Logger(const std::string& fileName, LogLevel level = LogLevel::INFO, bool enable = true) + : mMinLevel(level), mEnable(enable) { if (enable) { mLogFile.open(fileName, std::ios::app); @@ -59,43 +53,37 @@ class Logger } } - Logger(){} + Logger() {} ~Logger() { - if (mLogFile.is_open()) { - mLogFile.close(); - } + if (mLogFile.is_open()) { mLogFile.close(); } } - void SetLevel(LogLevel level) - { - mMinLevel = level; - } + void SetLevel(LogLevel level) { mMinLevel = level; } void log(LogLevel level, const char* format, ...) { - if (level < mMinLevel || !mLogFile.is_open() || !mEnable) { - return; - } + if (level < mMinLevel || !mLogFile.is_open() || !mEnable) { return; } std::lock_guard lock(mMutex); auto now = std::chrono::system_clock::now(); auto nowC = std::chrono::system_clock::to_time_t(now); auto duration = now.time_since_epoch(); - auto millis = std::chrono::duration_cast(duration).count() % 1000; - auto micros = std::chrono::duration_cast(duration).count() % 1000; + auto millis = + std::chrono::duration_cast(duration).count() % 1000; + auto micros = + std::chrono::duration_cast(duration).count() % 1000; std::tm localTime = *std::localtime(&nowC); char timeBuffer[26]; std::strftime(timeBuffer, sizeof(timeBuffer), "%Y-%m-%d %H:%M:%S", &localTime); - const char *levelStr = ""; - switch (level) - { - case LogLevel::DEBUG: levelStr = "DEBUG"; break; - case LogLevel::INFO: levelStr = "INFO"; break; - case LogLevel::WARNING: levelStr = "WARNING"; break; - case LogLevel::ERROR: levelStr = "ERROR"; break; - default: levelStr = "UNKNOWN"; break; + const char* levelStr = ""; + switch (level) { + case LogLevel::DEBUG: levelStr = "DEBUG"; break; + case LogLevel::INFO: levelStr = "INFO"; break; + case LogLevel::WARNING: levelStr = "WARNING"; break; + case LogLevel::ERROR: levelStr = "ERROR"; break; + default: levelStr = "UNKNOWN"; break; } char messageBuffer[4096]; va_list args; @@ -103,18 +91,14 @@ class Logger vsnprintf(messageBuffer, sizeof(messageBuffer), format, args); va_end(args); - mLogFile << timeBuffer << "." - << std::setfill('0') << std::setw(3) << millis << std::setw(3) - << micros << " " << "[" << levelStr << "]" - << messageBuffer; + mLogFile << timeBuffer << "." << std::setfill('0') << std::setw(3) << millis << std::setw(3) + << micros << " " << "[" << levelStr << "]" << messageBuffer; mLogFile.flush(); } void LogWOPrefix(LogLevel level, const char* format, ...) 
{ - if (level < mMinLevel || !mLogFile.is_open() || !mEnable) { - return; - } + if (level < mMinLevel || !mLogFile.is_open() || !mEnable) { return; } std::lock_guard lock(mMutex); char messageBuffer[2048]; va_list args; diff --git a/ucm/sparse/gsa/prefetch/include/kvcache_pre.h b/ucm/sparse/gsa/prefetch/include/kvcache_pre.h index d0389759..0a8e0b02 100644 --- a/ucm/sparse/gsa/prefetch/include/kvcache_pre.h +++ b/ucm/sparse/gsa/prefetch/include/kvcache_pre.h @@ -1,199 +1,179 @@ #ifndef ATB_KV_CACHE_PRE_H #define ATB_KV_CACHE_PRE_H -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include #include -#include #include -#include -#include -#include +#include #include -#include -#include +#include +#include +#include #include -#include #include +#include +#include +#include +#include +#include +#include #include +#include +#include +#include +#include +#include +#include +#include #include "../../../../store/ucmstore.h" namespace py = pybind11; -namespace ucmprefetch -{ - typedef struct { - int topkLen; - std::string reqID; - int layerID; - int topkIndex; - int bsIndex; - } PrefetchReqInfo; - - class ThreadPool - { - public: - static ThreadPool *GetInst() - { - static ThreadPool pool(1); - return &pool; - } - - ~ThreadPool(); - - template - auto enqueue(F&& f, Args&&... args) -> std::future::type>; - - size_t GetActiveThreads() const; - - private: - explicit ThreadPool(size_t threadCount); - std::vector workers; - std::queue> tasks; - mutable std::mutex queueMutex; - bool stop; - std::condition_variable condition; - std::atomic activeThreads{0}; - size_t maxThreads; - }; - - void MutliBSThreadFun(void *args); - - class __attribute__((visibility("hidden"))) GSAPrefetchEngineC +namespace ucmprefetch { +typedef struct { + int topkLen; + std::string reqID; + int layerID; + int topkIndex; + int bsIndex; +} PrefetchReqInfo; + +class ThreadPool { +public: + static ThreadPool* GetInst() { - private: - std::map>> mDocsTables; - std::map>> mBlocksMap; - torch::Tensor mLoadSuccessBlocks; - torch::Tensor mFreeBlock; - torch::Tensor mFreeBlockLen; - torch::Tensor mSuccessTableLen; - torch::Tensor mUseTopkIdxs; - int mLayerNum; - int mRank = -1; - uint32_t mMaxBs = 30; - std::vector mReqIdList; - int *mTopkLenList = NULL; - int *mBsIndexList = NULL; - uint32_t runBsLen = 0; - bool mIsLog = false; - bool mIsPrefetchDone = true; - bool mUseMla = false; - Logger mLogger; - ThreadPool *mThreadPool; - uint32_t mDecodeStep = 0; - uint32_t mMaxTopkLen = 0; - uint32_t mMaxBlocksLen = 0; - std::unordered_set mDelSeqIds; - std::map>> allNeedLoadBlock; - std::map>> allMissIdxs; - std::map mPromptLen; - UC::CCStore<> *mStore = nullptr; - std::vector mKvCaches; - uint32_t mBlockSize = 128; - uint32_t mTensorElemSize = 2; // fp16 - uint32_t mHeadNum = 40; - uint32_t mHeadSzie = 128; - uint32_t mTPSize = 2; - std::map> mAllBlcoksHash; - uint32_t mKVSzieBytes = 0; - uint32_t mExtraTopkLen = 16; - bool mIsPythonLoad = false; - public: - std::mutex mMutex; - bool mStopPrefetch = false; - - private: - void LoadKVToHBM(std::vector loadNPUBlockIDs, - std::vector missIdxs, int layerID, std::string reqID); - - void GetHitAndMissBlock(PrefetchReqInfo oneBsInfo, - std::unordered_set &hitBlocks, - std::map &hitBlocksIdx, - std::vector &missIdxs); - - void RunPrefetchH2D(PrefetchReqInfo oneBsInfo, - std::unordered_set &hitBlocks, - std::map &hitBlocksIdx, - std::vector &missIdxs); - - void RunOneBsPrefetch(std::string reqID, int topkLen, - int bsIndex, int topkIndex); - - public: - 
~GSAPrefetchEngineC(); - - GSAPrefetchEngineC(torch::Tensor &freeBlock, - torch::Tensor &loadSuccessBlocks, - torch::Tensor &freeBlockLen, - torch::Tensor &successTableLen, - std::vector &kvShape, - bool useMla, - bool isLog, - int tpSize, - int rank, - int extraTopkLen, - bool isPythonLoad - ); - - void SetBlocksMap(std::string reqID, std::vector &blockTableList, - std::vector &selectIndex, std::vector &blocksHash, - int maxIdx); - - void SetBlocksMapMultiLayer(std::string reqID, - std::vector> &remainMap, - std::vector> &prefetchMap, - std::vector &blocksHash, - int maxIdx); - - void CheckInputIndex(uint32_t maxLen, uint32_t index); - - void AddBlocksMap(std::string reqID, int idx, int blockID); - - void DelBlocksMap(std::string reqID); - - void DelReqIDRun(); - - void SetBlockTableInfo(torch::Tensor &blockTables, - torch::Tensor &blockLengths, - torch::Tensor &inputTopkBuf, int step); - - void RunAsyncPrefetchBs(std::vector &reqIDsInput, - std::vector &topkLensInput, - std::vector &bsIndexInput, - std::vector &kvCaches, - void *storePtr); - - int CallPrefetchProcessFun(); - - void PrintMap(std::string reqID, int i); - - bool GetPrefetchStatus(); - - void SetPrefetchStatus(bool flag); - - void SetModelRunningStatus(bool flag); - - size_t GetOffset(uint32_t layerID, bool isV); - - size_t GetOffsetNew(uint32_t layerID, bool isV); - - std::map>> ObtainLoadBlocks(); - - std::map>> ObtainMissIdxs(); + static ThreadPool pool(1); + return &pool; + } + + ~ThreadPool(); + + template + auto enqueue(F&& f, Args&&... args) -> std::future::type>; + + size_t GetActiveThreads() const; + +private: + explicit ThreadPool(size_t threadCount); + std::vector workers; + std::queue> tasks; + mutable std::mutex queueMutex; + bool stop; + std::condition_variable condition; + std::atomic activeThreads{0}; + size_t maxThreads; +}; + +void MutliBSThreadFun(void* args); + +class __attribute__((visibility("hidden"))) GSAPrefetchEngineC { +private: + std::map>> mDocsTables; + std::map>> mBlocksMap; + torch::Tensor mLoadSuccessBlocks; + torch::Tensor mFreeBlock; + torch::Tensor mFreeBlockLen; + torch::Tensor mSuccessTableLen; + torch::Tensor mUseTopkIdxs; + int mLayerNum; + int mRank = -1; + uint32_t mMaxBs = 30; + std::vector mReqIdList; + int* mTopkLenList = NULL; + int* mBsIndexList = NULL; + uint32_t runBsLen = 0; + bool mIsLog = false; + bool mIsPrefetchDone = true; + bool mUseMla = false; + Logger mLogger; + ThreadPool* mThreadPool; + uint32_t mDecodeStep = 0; + uint32_t mMaxTopkLen = 0; + uint32_t mMaxBlocksLen = 0; + std::unordered_set mDelSeqIds; + std::map>> allNeedLoadBlock; + std::map>> allMissIdxs; + std::map mPromptLen; + UC::CCStore<>* mStore = nullptr; + std::vector mKvCaches; + uint32_t mBlockSize = 128; + uint32_t mTensorElemSize = 2; // fp16 + uint32_t mHeadNum = 40; + uint32_t mHeadSzie = 128; + uint32_t mTPSize = 2; + std::map> mAllBlcoksHash; + uint32_t mKVSzieBytes = 0; + uint32_t mExtraTopkLen = 16; + bool mIsPythonLoad = false; + +public: + std::mutex mMutex; + bool mStopPrefetch = false; + +private: + void LoadKVToHBM(std::vector loadNPUBlockIDs, std::vector missIdxs, int layerID, + std::string reqID); + + void GetHitAndMissBlock(PrefetchReqInfo oneBsInfo, std::unordered_set& hitBlocks, + std::map& hitBlocksIdx, std::vector& missIdxs); + + void RunPrefetchH2D(PrefetchReqInfo oneBsInfo, std::unordered_set& hitBlocks, + std::map& hitBlocksIdx, std::vector& missIdxs); + + void RunOneBsPrefetch(std::string reqID, int topkLen, int bsIndex, int topkIndex); + +public: + ~GSAPrefetchEngineC(); + + 
GSAPrefetchEngineC(torch::Tensor& freeBlock, torch::Tensor& loadSuccessBlocks, + torch::Tensor& freeBlockLen, torch::Tensor& successTableLen, + std::vector& kvShape, bool useMla, bool isLog, int tpSize, + int rank, int extraTopkLen, bool isPythonLoad); + + void SetBlocksMap(std::string reqID, std::vector& blockTableList, + std::vector& selectIndex, std::vector& blocksHash, + int maxIdx); + + void SetBlocksMapMultiLayer(std::string reqID, std::vector>& remainMap, + std::vector>& prefetchMap, + std::vector& blocksHash, int maxIdx); + + void CheckInputIndex(uint32_t maxLen, uint32_t index); + + void AddBlocksMap(std::string reqID, int idx, int blockID); + + void DelBlocksMap(std::string reqID); + + void DelReqIDRun(); + + void SetBlockTableInfo(torch::Tensor& blockTables, torch::Tensor& blockLengths, + torch::Tensor& inputTopkBuf, int step); + + void RunAsyncPrefetchBs(std::vector& reqIDsInput, std::vector& topkLensInput, + std::vector& bsIndexInput, std::vector& kvCaches, + void* storePtr); + + int CallPrefetchProcessFun(); + + void PrintMap(std::string reqID, int i); + + bool GetPrefetchStatus(); + + void SetPrefetchStatus(bool flag); + + void SetModelRunningStatus(bool flag); + + size_t GetOffset(uint32_t layerID, bool isV); + + size_t GetOffsetNew(uint32_t layerID, bool isV); + + std::map>> ObtainLoadBlocks(); + + std::map>> ObtainMissIdxs(); - std::map>> ObtainBlocksMap(); + std::map>> ObtainBlocksMap(); - std::map>> ObtainDocsMap(); - }; + std::map>> ObtainDocsMap(); +}; -} // namespace uc +} // namespace ucmprefetch #endif diff --git a/ucm/sparse/gsa/prefetch/src/kvcache_pre.cpp b/ucm/sparse/gsa/prefetch/src/kvcache_pre.cpp index a6f61b88..d7670afa 100644 --- a/ucm/sparse/gsa/prefetch/src/kvcache_pre.cpp +++ b/ucm/sparse/gsa/prefetch/src/kvcache_pre.cpp @@ -1,609 +1,579 @@ #include "kvcache_pre.h" #include -#include #include +#include -namespace ucmprefetch +namespace ucmprefetch { +ThreadPool::ThreadPool(size_t threadCount) : stop(false), maxThreads(threadCount) { - ThreadPool::ThreadPool(size_t threadCount) - :stop(false), maxThreads(threadCount) - { - for (size_t i = 0; i < maxThreads; i++) { - workers.emplace_back([this] { - while(true) { - std::function task; - { - std::unique_lock lock(this->queueMutex); - this->condition.wait(lock, [this] { - return this->stop || !this->tasks.empty(); - }); - - if (this->stop && this->tasks.empty()) { - return; - } - - task = std::move(this->tasks.front()); - this->tasks.pop(); - ++activeThreads; - } - - task(); - { - std::unique_lock lock(this->queueMutex); - --activeThreads; - condition.notify_all(); - } + for (size_t i = 0; i < maxThreads; i++) { + workers.emplace_back([this] { + while (true) { + std::function task; + { + std::unique_lock lock(this->queueMutex); + this->condition.wait(lock, + [this] { return this->stop || !this->tasks.empty(); }); + + if (this->stop && this->tasks.empty()) { return; } + + task = std::move(this->tasks.front()); + this->tasks.pop(); + ++activeThreads; } - }); - } + + task(); + { + std::unique_lock lock(this->queueMutex); + --activeThreads; + condition.notify_all(); + } + } + }); } - ThreadPool::~ThreadPool() +} +ThreadPool::~ThreadPool() +{ { - { - std::unique_lock lock(queueMutex); - stop = true; - } - condition.notify_all(); - for (std::thread &worker : workers) { - worker.join(); - } + std::unique_lock lock(queueMutex); + stop = true; } + condition.notify_all(); + for (std::thread& worker : workers) { worker.join(); } +} - template - auto ThreadPool::enqueue(F&& f, Args&&... 
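The ThreadPool constructor above is a standard condition-variable worker loop. The rendered hunk strips the template arguments, so this sketch restores them (std::function<void()>, std::unique_lock<std::mutex>) around the same wait/pop/run cycle, including the drain-then-exit shutdown used by the destructor:

#include <condition_variable>
#include <cstddef>
#include <functional>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

struct MiniPool {
    std::mutex queueMutex;
    std::condition_variable condition;
    std::queue<std::function<void()>> tasks;
    bool stop = false;
    std::vector<std::thread> workers;

    explicit MiniPool(std::size_t threadCount)
    {
        for (std::size_t i = 0; i < threadCount; i++) {
            workers.emplace_back([this] {
                while (true) {
                    std::function<void()> task;
                    {
                        std::unique_lock<std::mutex> lock(queueMutex);
                        condition.wait(lock, [this] { return stop || !tasks.empty(); });
                        if (stop && tasks.empty()) return; // drain queue, then exit
                        task = std::move(tasks.front());
                        tasks.pop();
                    }
                    task(); // run outside the lock so more work can be queued
                }
            });
        }
    }
    ~MiniPool()
    {
        {
            std::unique_lock<std::mutex> lock(queueMutex);
            stop = true;
        }
        condition.notify_all();
        for (std::thread& w : workers) w.join();
    }
};

int main()
{
    MiniPool pool(1);
    {
        std::unique_lock<std::mutex> lock(pool.queueMutex);
        pool.tasks.push([] { /* prefetch work would run here */ });
    }
    pool.condition.notify_one();
    return 0;
}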
args) - -> std::future::type> - { - using return_type = typename std::result_of::type; +template +auto ThreadPool::enqueue(F&& f, Args&&... args) + -> std::future::type> +{ + using return_type = typename std::result_of::type; - auto task = std::make_shared>( - std::bind(std::forward(f), std::forward(args)...) - ); + auto task = std::make_shared>( + std::bind(std::forward(f), std::forward(args)...)); - std::future res = task->get_future(); - { - std::unique_lock lock(queueMutex); + std::future res = task->get_future(); + { + std::unique_lock lock(queueMutex); - condition.wait(lock, [this] { - if (!(activeThreads < maxThreads || tasks.size() < maxThreads * 2)) { - std::cout << "Need wait: " << activeThreads << " " << tasks.size() << std::endl; - } - return (activeThreads < maxThreads || tasks.size() < maxThreads * 2); - }); - // don't allow enqueueing after stopping the pool - if(stop) { - throw std::runtime_error("enqueue on stopped ThreadPool"); + condition.wait(lock, [this] { + if (!(activeThreads < maxThreads || tasks.size() < maxThreads * 2)) { + std::cout << "Need wait: " << activeThreads << " " << tasks.size() << std::endl; } + return (activeThreads < maxThreads || tasks.size() < maxThreads * 2); + }); + // don't allow enqueueing after stopping the pool + if (stop) { throw std::runtime_error("enqueue on stopped ThreadPool"); } - tasks.emplace([task](){ (*task)(); }); - } - condition.notify_one(); - return res; + tasks.emplace([task]() { (*task)(); }); } + condition.notify_one(); + return res; +} - size_t ThreadPool::GetActiveThreads() const - { - return activeThreads; +size_t ThreadPool::GetActiveThreads() const { return activeThreads; } + +void MutliBSThreadFun(void* args) +{ + GSAPrefetchEngineC* engine = static_cast(args); + int ret = engine->CallPrefetchProcessFun(); + engine->mMutex.lock(); + engine->DelReqIDRun(); + engine->mMutex.unlock(); + if (ret == 0) { engine->SetPrefetchStatus(true); } +} + +GSAPrefetchEngineC::GSAPrefetchEngineC(torch::Tensor& freeBlock, torch::Tensor& loadSuccessBlocks, + torch::Tensor& freeBlockLen, torch::Tensor& successTableLen, + std::vector& kvShape, bool useMla, bool isLog, + int tpSize, int rank, int extraTopkLen, bool isPythonLoad) + : mLogger("./log/kvcache_pre_log.txt", LogLevel::INFO, isLog) +{ + mLoadSuccessBlocks = loadSuccessBlocks; + mLayerNum = mLoadSuccessBlocks.sizes()[0]; + mMaxBs = mLoadSuccessBlocks.sizes()[1]; + mMaxTopkLen = mLoadSuccessBlocks.sizes()[2]; + mFreeBlock = freeBlock; + mFreeBlockLen = freeBlockLen; + mSuccessTableLen = successTableLen; + mIsLog = isLog; + mBsIndexList = (int*)malloc(sizeof(int) * mMaxBs); + mTopkLenList = (int*)malloc(sizeof(int) * mMaxBs); + mIsPrefetchDone = true; + mThreadPool = ThreadPool::GetInst(); + mUseMla = useMla; + mHeadSzie = kvShape[2]; + mHeadNum = kvShape[1]; + mBlockSize = kvShape[0]; + mTPSize = tpSize; + mRank = rank; + mIsPythonLoad = isPythonLoad; + if (mRank != 0) { + mLogger.SetLevel(LogLevel::WARNING); + mIsLog = false; + } + mExtraTopkLen = extraTopkLen; + mLogger.log(LogLevel::INFO, + "GSAPrefetchEngineC Init mLayerNum %d mMaxBs %u, mUseMla %d, mHeadSzie %u, mTPSize " + "%u mBlockSize %u mHeadNum %u\n", + mLayerNum, mMaxBs, mUseMla, mHeadSzie, mTPSize, mBlockSize, mHeadNum); +} + +size_t GSAPrefetchEngineC::GetOffset(uint32_t layerID, bool isV) +{ + size_t kMinDataBlockSize = + static_cast(mBlockSize) * mHeadNum * mHeadSzie * mTensorElemSize; + size_t vMinDataBlockSize = kMinDataBlockSize; + size_t layerSize = (kMinDataBlockSize + vMinDataBlockSize) * mTPSize; + if (mUseMla) 
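enqueue() follows the usual packaged_task pattern: bind the call, share ownership with the queued lambda, and hand the caller a future. Reconstructed with its stripped template parameters (keeping std::result_of as in the source, although C++20 removes it), and with a bare queue standing in for the pool's guarded one:

#include <functional>
#include <future>
#include <memory>
#include <queue>
#include <utility>

std::queue<std::function<void()>> tasks; // stand-in for the pool's queue

template <class F, class... Args>
auto enqueue(F&& f, Args&&... args)
    -> std::future<typename std::result_of<F(Args...)>::type>
{
    using return_type = typename std::result_of<F(Args...)>::type;
    // Wrap the bound call in a packaged_task so the caller gets a future.
    auto task = std::make_shared<std::packaged_task<return_type()>>(
        std::bind(std::forward<F>(f), std::forward<Args>(args)...));
    std::future<return_type> res = task->get_future();
    tasks.emplace([task]() { (*task)(); });
    return res;
}

int main()
{
    auto fut = enqueue([](int x) { return x * 2; }, 21);
    tasks.front()(); // a worker thread would normally execute this
    return fut.get() == 42 ? 0 : 1;
}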
{ + vMinDataBlockSize = 0; + layerSize = kMinDataBlockSize; + } + size_t kOffset = 0; + if (mUseMla) { + kOffset = layerSize * layerID; + } else { + kOffset = layerSize * layerID + layerSize / mTPSize * mRank; } + size_t vOffset = kOffset + kMinDataBlockSize; + if (isV) { + return vOffset; + } else { + return kOffset; + } +} - void MutliBSThreadFun(void *args) - { - GSAPrefetchEngineC *engine = static_cast(args); - int ret = engine->CallPrefetchProcessFun(); - engine->mMutex.lock(); - engine->DelReqIDRun(); - engine->mMutex.unlock(); - if (ret == 0) { - engine->SetPrefetchStatus(true); - } +size_t GSAPrefetchEngineC::GetOffsetNew(uint32_t layerID, bool isV) +{ + size_t kMinDataBlockSize = + static_cast(mBlockSize) * mHeadNum * mHeadSzie * mTensorElemSize; + size_t layerSize = kMinDataBlockSize * 2; + size_t kOffset = layerSize * layerID; + if (mUseMla) { + layerSize = kMinDataBlockSize; + kOffset = layerSize * layerID; + return kOffset; } + size_t vOffset = kOffset + kMinDataBlockSize; - GSAPrefetchEngineC::GSAPrefetchEngineC(torch::Tensor &freeBlock, - torch::Tensor &loadSuccessBlocks, - torch::Tensor &freeBlockLen, - torch::Tensor &successTableLen, - std::vector &kvShape, - bool useMla, - bool isLog, - int tpSize, - int rank, - int extraTopkLen, - bool isPythonLoad - ) - :mLogger("./log/kvcache_pre_log.txt", LogLevel::INFO, isLog) - { - mLoadSuccessBlocks = loadSuccessBlocks; - mLayerNum = mLoadSuccessBlocks.sizes()[0]; - mMaxBs = mLoadSuccessBlocks.sizes()[1]; - mMaxTopkLen = mLoadSuccessBlocks.sizes()[2]; - mFreeBlock = freeBlock; - mFreeBlockLen = freeBlockLen; - mSuccessTableLen = successTableLen; - mIsLog = isLog; - mBsIndexList = (int *)malloc(sizeof(int) * mMaxBs); - mTopkLenList = (int *)malloc(sizeof(int) * mMaxBs); - mIsPrefetchDone = true; - mThreadPool = ThreadPool::GetInst(); - mUseMla = useMla; - mHeadSzie = kvShape[2]; - mHeadNum = kvShape[1]; - mBlockSize = kvShape[0]; - mTPSize = tpSize; - mRank = rank; - mIsPythonLoad = isPythonLoad; - if(mRank != 0) { - mLogger.SetLevel(LogLevel::WARNING); - mIsLog = false; - } - mExtraTopkLen = extraTopkLen; - mLogger.log(LogLevel::INFO, - "GSAPrefetchEngineC Init mLayerNum %d mMaxBs %u, mUseMla %d, mHeadSzie %u, mTPSize %u mBlockSize %u mHeadNum %u\n", - mLayerNum, mMaxBs, mUseMla, mHeadSzie, mTPSize, mBlockSize, mHeadNum); + if (isV) { + return vOffset; + } else { + return kOffset; } +} - size_t GSAPrefetchEngineC::GetOffset(uint32_t layerID, bool isV) - { - size_t kMinDataBlockSize = static_cast(mBlockSize) * mHeadNum * mHeadSzie * mTensorElemSize; - size_t vMinDataBlockSize = kMinDataBlockSize; - size_t layerSize = (kMinDataBlockSize + vMinDataBlockSize) * mTPSize; - if (mUseMla) { - vMinDataBlockSize = 0; - layerSize = kMinDataBlockSize; - } - size_t kOffset = 0; - if (mUseMla) { - kOffset = layerSize * layerID; - } else { - kOffset = layerSize * layerID + layerSize / mTPSize * mRank; - } - size_t vOffset = kOffset + kMinDataBlockSize; - if (isV) { - return vOffset; - } else { - return kOffset; - } +void GSAPrefetchEngineC::CheckInputIndex(uint32_t maxLen, uint32_t index) +{ + if (index >= maxLen) { + mLogger.log(LogLevel::ERROR, + "Decode step: %u, |KVCache Prefetch| index error! 
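As a worked check of GetOffsetNew() with the defaults declared in the header (mBlockSize 128, mHeadNum 40, mHeadSzie 128, fp16 so mTensorElemSize 2): one K or V block is 128 * 40 * 128 * 2 = 1,310,720 bytes, and in the non-MLA layout each layer stores K then V back to back. Unlike GetOffset(), the New variant drops the mTPSize and mRank terms from the layout:

#include <cstdio>

int main()
{
    const unsigned long long blockSize = 128, headNum = 40, headSize = 128, elemSize = 2;
    const unsigned long long kBlock = blockSize * headNum * headSize * elemSize; // 1310720
    const unsigned long long layerSize = kBlock * 2; // K and V, non-MLA path
    const unsigned long long layerID = 3;
    unsigned long long kOffset = layerSize * layerID; // 7864320
    unsigned long long vOffset = kOffset + kBlock;    // 9175040
    std::printf("layer %llu: k=%llu v=%llu\n", layerID, kOffset, vOffset);
    return 0;
}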
index: %u, maxLen: %u\n", + mDecodeStep, index, maxLen); + std::abort(); } +} - size_t GSAPrefetchEngineC::GetOffsetNew(uint32_t layerID, bool isV) - { - size_t kMinDataBlockSize = static_cast(mBlockSize) * mHeadNum * mHeadSzie * mTensorElemSize; - size_t layerSize = kMinDataBlockSize * 2; - size_t kOffset = layerSize * layerID; - if (mUseMla) { - layerSize = kMinDataBlockSize; - kOffset = layerSize * layerID; - return kOffset; - } - size_t vOffset = kOffset + kMinDataBlockSize; +GSAPrefetchEngineC::~GSAPrefetchEngineC() +{ + free(mBsIndexList); + free(mTopkLenList); +} - if (isV) { - return vOffset; - } else { - return kOffset; - } +void GSAPrefetchEngineC::SetBlocksMap(std::string reqID, std::vector& blockTableList, + std::vector& selectIndex, + std::vector& blocksHash, int maxIdx) +{ + if (mBlocksMap.find(reqID) != mBlocksMap.end()) { + mBlocksMap[reqID].clear(); + mDocsTables[reqID].clear(); + mAllBlcoksHash[reqID].clear(); } - - void GSAPrefetchEngineC::CheckInputIndex(uint32_t maxLen, uint32_t index) - { - if (index >= maxLen) { - mLogger.log(LogLevel::ERROR, - "Decode step: %u, |KVCache Prefetch| index error! index: %u, maxLen: %u\n", - mDecodeStep, index, maxLen); - std::abort(); + mAllBlcoksHash[reqID] = blocksHash; + for (int i = 0; i < mLayerNum; i++) { + std::map oneDocTable; + std::map oneBlockMap; + for (auto idx : selectIndex) { + oneDocTable[idx] = blockTableList[idx]; + oneBlockMap[blockTableList[idx]] = idx; } + mDocsTables[reqID].push_back(oneDocTable); + mBlocksMap[reqID].push_back(oneBlockMap); } - - GSAPrefetchEngineC::~GSAPrefetchEngineC() - { - free(mBsIndexList); - free(mTopkLenList); + mPromptLen[reqID] = maxIdx; + PrintMap(reqID, 0); +} + +void GSAPrefetchEngineC::SetBlocksMapMultiLayer(std::string reqID, + std::vector>& remainMap, + std::vector>& prefetchMap, + std::vector& blocksHash, int maxIdx) +{ + if (mBlocksMap.find(reqID) != mBlocksMap.end()) { + mBlocksMap[reqID].clear(); + mDocsTables[reqID].clear(); + mAllBlcoksHash[reqID].clear(); } - - void GSAPrefetchEngineC::SetBlocksMap(std::string reqID, std::vector &blockTableList, - std::vector &selectIndex, std::vector &blocksHash, int maxIdx) - { - if (mBlocksMap.find(reqID) != mBlocksMap.end()) { - mBlocksMap[reqID].clear(); - mDocsTables[reqID].clear(); - mAllBlcoksHash[reqID].clear(); + mAllBlcoksHash[reqID] = blocksHash; + for (int i = 0; i < mLayerNum; i++) { + std::map oneDocTable; + std::map oneBlockMap; + for (auto it = remainMap[i].begin(); it != remainMap[i].end(); it++) { + oneDocTable[it->first] = it->second; + oneBlockMap[it->second] = it->first; } - mAllBlcoksHash[reqID] = blocksHash; - for (int i = 0; i < mLayerNum; i++) { - std::map oneDocTable; - std::map oneBlockMap; - for (auto idx:selectIndex) { - oneDocTable[idx] = blockTableList[idx]; - oneBlockMap[blockTableList[idx]] = idx; - } - mDocsTables[reqID].push_back(oneDocTable); - mBlocksMap[reqID].push_back(oneBlockMap); + for (auto it = prefetchMap[i].begin(); it != prefetchMap[i].end(); it++) { + oneDocTable[it->first] = it->second; + oneBlockMap[it->second] = it->first; } - mPromptLen[reqID] = maxIdx; - PrintMap(reqID, 0); + mDocsTables[reqID].push_back(oneDocTable); + mBlocksMap[reqID].push_back(oneBlockMap); } + mPromptLen[reqID] = maxIdx; +} - void GSAPrefetchEngineC::SetBlocksMapMultiLayer(std::string reqID, std::vector> &remainMap, - std::vector> &prefetchMap, std::vector &blocksHash, int maxIdx) - { - if (mBlocksMap.find(reqID) != mBlocksMap.end()) { - mBlocksMap[reqID].clear(); - mDocsTables[reqID].clear(); - 
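SetBlocksMap builds a forward and an inverse table per layer: oneDocTable maps a logical (prompt-order) block index to the NPU block holding it, and oneBlockMap inverts that, which is what lets GetHitAndMissBlock resolve topk indices and LoadKVToHBM re-home evicted blocks later in this file. A small self-contained trace of the loop above:

#include <map>
#include <vector>

int main()
{
    std::vector<int> blockTableList = {7, 8, 9, 10}; // idx -> NPU block
    std::vector<int> selectIndex = {0, 2};           // resident indices

    std::map<int, int> oneDocTable; // logical idx -> NPU block ID
    std::map<int, int> oneBlockMap; // NPU block ID -> logical idx
    for (int idx : selectIndex) {
        oneDocTable[idx] = blockTableList[idx];
        oneBlockMap[blockTableList[idx]] = idx;
    }
    // oneDocTable == {0: 7, 2: 9}, oneBlockMap == {7: 0, 9: 2}
    return (oneDocTable[2] == 9 && oneBlockMap[9] == 2) ? 0 : 1;
}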
mAllBlcoksHash[reqID].clear(); - } - mAllBlcoksHash[reqID] = blocksHash; - for (int i = 0; i < mLayerNum; i++) { +void GSAPrefetchEngineC::AddBlocksMap(std::string reqID, int idx, int blockID) +{ + if (mBlocksMap.find(reqID) == mBlocksMap.end()) { + for (int i = 0; i < mLayerNum; ++i) { std::map oneDocTable; std::map oneBlockMap; - for (auto it = remainMap[i].begin(); it != remainMap[i].end(); it++) { - oneDocTable[it->first] = it->second; - oneBlockMap[it->second] = it->first; - } - for (auto it = prefetchMap[i].begin(); it != prefetchMap[i].end(); it++) { - oneDocTable[it->first] = it->second; - oneBlockMap[it->second] = it->first; - } + oneDocTable[idx] = blockID; + oneBlockMap[blockID] = idx; mDocsTables[reqID].push_back(oneDocTable); mBlocksMap[reqID].push_back(oneBlockMap); } - mPromptLen[reqID] = maxIdx; + } else { + for (int i = 0; i < mLayerNum; i++) { + mDocsTables[reqID][i][idx] = blockID; + mBlocksMap[reqID][i][blockID] = idx; + } } +} - void GSAPrefetchEngineC::AddBlocksMap(std::string reqID, int idx, int blockID) - { - if (mBlocksMap.find(reqID) == mBlocksMap.end()) { - for (int i = 0; i < mLayerNum; ++i) { - std::map oneDocTable; - std::map oneBlockMap; - oneDocTable[idx] = blockID; - oneBlockMap[blockID] = idx; - mDocsTables[reqID].push_back(oneDocTable); - mBlocksMap[reqID].push_back(oneBlockMap); - } +void GSAPrefetchEngineC::DelBlocksMap(std::string reqID) +{ + mMutex.lock(); + mDelSeqIds.insert(reqID); + if (mIsPrefetchDone) { DelReqIDRun(); } + mMutex.unlock(); +} + +void GSAPrefetchEngineC::DelReqIDRun() +{ + for (auto it = mDelSeqIds.begin(); it != mDelSeqIds.end(); it++) { + if (mBlocksMap.find(*it) == mBlocksMap.end()) { + continue; } else { - for (int i = 0; i < mLayerNum; i++) { - mDocsTables[reqID][i][idx] = blockID; - mBlocksMap[reqID][i][blockID] = idx; - } + mBlocksMap.erase(*it); + mDocsTables.erase(*it); + mAllBlcoksHash.erase(*it); + mPromptLen.erase(*it); + std::cout << "Del reqID: " << *it << std::endl; } - } - - void GSAPrefetchEngineC::DelBlocksMap(std::string reqID) - { - mMutex.lock(); - mDelSeqIds.insert(reqID); - if (mIsPrefetchDone) { - DelReqIDRun(); + if (mPromptLen.find(*it) == mPromptLen.end()) { + continue; + } else { + mPromptLen.erase(*it); } - mMutex.unlock(); } + mDelSeqIds.clear(); +} - void GSAPrefetchEngineC::DelReqIDRun() - { - for (auto it = mDelSeqIds.begin(); it != mDelSeqIds.end(); it++) { - if (mBlocksMap.find(*it) == mBlocksMap.end()) { - continue; - } else { - mBlocksMap.erase(*it); - mDocsTables.erase(*it); - mAllBlcoksHash.erase(*it); - mPromptLen.erase(*it); - std::cout << "Del reqID: " << *it << std::endl; - } - if (mPromptLen.find(*it) == mPromptLen.end()) { - continue; - } else { - mPromptLen.erase(*it); - } +void GSAPrefetchEngineC::PrintMap(std::string reqID, int i) +{ + std::ostringstream oss; + oss << "Decode step: " << mDecodeStep << " Rnak: " << mRank << " reqID: " << reqID + << " layerID: " << i << "mDocsTables"; + for (auto it : mDocsTables[reqID][i]) { oss << "(" << it.first << ", " << it.second << ")"; } + oss << "------\n"; + mLogger.log(LogLevel::INFO, oss.str().c_str()); + oss.str(""); + oss << "Decode step: " << mDecodeStep << " Rnak: " << mRank << " reqID: " << reqID + << " layerID: " << i << "mBlocksMap"; + for (auto it : mBlocksMap[reqID][i]) { oss << "(" << it.first << ", " << it.second << ")"; } + oss << "------\n"; + mLogger.log(LogLevel::INFO, oss.str().c_str()); + oss.str(""); +} + +void GSAPrefetchEngineC::GetHitAndMissBlock(PrefetchReqInfo oneBsInfo, + std::unordered_set& hitBlocks, + std::map& 
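DelBlocksMap never tears down state while a prefetch may be reading it: the request ID is parked in mDelSeqIds, and the erase happens in DelReqIDRun once the engine is idle (MutliBSThreadFun takes mMutex and calls it after each batch). Note that the second mPromptLen check inside DelReqIDRun can never fire, since the first branch either continues or has already erased that entry. A stripped-down sketch of the handshake:

#include <mutex>
#include <string>
#include <unordered_set>

struct Engine {
    std::mutex mMutex;
    std::unordered_set<std::string> mDelSeqIds;
    bool mIsPrefetchDone = true;

    void DelBlocksMap(const std::string& reqID)
    {
        std::lock_guard<std::mutex> lock(mMutex);
        mDelSeqIds.insert(reqID);             // park the ID; workers may still read
        if (mIsPrefetchDone) DelReqIDRun();   // idle, so erase immediately
    }
    void DelReqIDRun()
    {
        // the real code erases each parked reqID from mBlocksMap, mDocsTables,
        // mAllBlcoksHash and mPromptLen here
        mDelSeqIds.clear();
    }
};

int main()
{
    Engine e;
    e.DelBlocksMap("req-0");
    return e.mDelSeqIds.empty() ? 0 : 1;
}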
hitBlocksIdx, + std::vector& missIdxs) +{ + int topkLen = oneBsInfo.topkLen; + int layerID = oneBsInfo.layerID; + std::string reqID = oneBsInfo.reqID; + int topkIndex = oneBsInfo.topkIndex; + + std::ostringstream oss; + oss << "Decode step: " << mDecodeStep << " Rnak: " << mRank << " reqID: " << reqID + << " layerID: " << layerID << " topk len: " << topkLen << " topk: "; + for (int j = 0; j < topkLen; j++) { + int64_t item = 0; + if (mUseTopkIdxs.scalar_type() == torch::kInt32) { + item = mUseTopkIdxs[layerID][topkIndex][j].item(); + } else { + item = mUseTopkIdxs[layerID][topkIndex][j].item(); + } + oss << item << " "; + if (mDocsTables[reqID][layerID].find(item) != mDocsTables[reqID][layerID].end()) { + int blockID = mDocsTables[reqID][layerID][item]; + hitBlocks.insert(blockID); + hitBlocksIdx.insert(std::make_pair(item, blockID)); + if (hitBlocks.size() == (topkLen - mExtraTopkLen)) { break; } + } else { + missIdxs.push_back(item); } - mDelSeqIds.clear(); } + oss << "------\n"; + mLogger.log(LogLevel::DEBUG, oss.str().c_str()); + oss.str(""); + if ((hitBlocks.size() + missIdxs.size()) != (uint32_t)topkLen && + hitBlocks.size() != (topkLen - mExtraTopkLen)) { + mLogger.log(LogLevel::ERROR, + "|KVCache Prefetch| Decode step: %u, Rank: %d, reqID: %s, layer: %d, hit size: " + "%lu, miss size: %lu , topkLen: %d, not equal error\n", + mDecodeStep, mRank, reqID, layerID, hitBlocks.size(), missIdxs.size(), topkLen); + PrintMap(reqID, layerID); + } +} - void GSAPrefetchEngineC::PrintMap(std::string reqID, int i) - { - std::ostringstream oss; - oss << "Decode step: " << mDecodeStep << " Rnak: " << mRank << " reqID: " - << reqID << " layerID: " << i << "mDocsTables"; - for (auto it : mDocsTables[reqID][i]) { - oss << "(" << it.first << ", " << it.second << ")"; - } - oss << "------\n"; - mLogger.log(LogLevel::INFO, oss.str().c_str()); - oss.str(""); - oss << "Decode step: " << mDecodeStep << " Rnak: " << mRank << " reqID: " - << reqID << " layerID: " << i << "mBlocksMap"; - for (auto it : mBlocksMap[reqID][i]) { - oss << "(" << it.first << ", " << it.second << ")"; +void GSAPrefetchEngineC::RunPrefetchH2D(PrefetchReqInfo oneBsInfo, + std::unordered_set& hitBlocks, + std::map& hitBlocksIdx, + std::vector& missIdxs) +{ + int layerID = oneBsInfo.layerID; + std::string reqID = oneBsInfo.reqID; + uint32_t topkLen = oneBsInfo.topkLen; + int bsIndex = oneBsInfo.bsIndex; + + int oneFreeBlockLen = mFreeBlockLen[layerID][bsIndex].item(); + int* freeBlockPtr = mFreeBlock[layerID][bsIndex].data_ptr(); + std::vector oneFreeBlockTable; + + uint32_t index = 0; + int oneFreeBlockIndex = 0; + while (oneFreeBlockIndex < oneFreeBlockLen && index < missIdxs.size() && + hitBlocks.size() < (topkLen - mExtraTopkLen)) { + int oneFreeBlockID = freeBlockPtr[oneFreeBlockIndex]; + if (hitBlocks.find(oneFreeBlockID) != hitBlocks.end()) { + oneFreeBlockIndex += 1; + continue; + } else { + oneFreeBlockTable.push_back(oneFreeBlockID); + hitBlocks.insert(oneFreeBlockID); + hitBlocksIdx.insert(std::make_pair(missIdxs[index], oneFreeBlockID)); + index += 1; + oneFreeBlockIndex += 1; } - oss << "------\n"; - mLogger.log(LogLevel::INFO, oss.str().c_str()); - oss.str(""); } - - void GSAPrefetchEngineC::GetHitAndMissBlock(PrefetchReqInfo oneBsInfo, - std::unordered_set &hitBlocks, - std::map &hitBlocksIdx, - std::vector &missIdxs) - { - int topkLen = oneBsInfo.topkLen; - int layerID = oneBsInfo.layerID; - std::string reqID = oneBsInfo.reqID; - int topkIndex = oneBsInfo.topkIndex; - - std::ostringstream oss; - oss << "Decode step: " << 
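The classification loop above walks the topk list for one layer: indices already present in mDocsTables count as hits (keyed by their resident NPU block), everything else is recorded as a miss, and the scan stops early once topkLen - mExtraTopkLen hits are banked. In miniature:

#include <cstdint>
#include <map>
#include <unordered_set>
#include <vector>

int main()
{
    std::map<int64_t, int> docsTable = {{0, 7}, {2, 9}}; // idx -> resident blockID
    std::vector<int64_t> topk = {2, 5, 0, 6};
    const std::size_t wanted = 2; // topkLen - extraTopkLen

    std::unordered_set<int> hitBlocks;
    std::map<int64_t, int> hitBlocksIdx;
    std::vector<int64_t> missIdxs;
    for (int64_t item : topk) {
        auto it = docsTable.find(item);
        if (it != docsTable.end()) {
            hitBlocks.insert(it->second);
            hitBlocksIdx.emplace(item, it->second);
            if (hitBlocks.size() == wanted) break; // early exit on quota
        } else {
            missIdxs.push_back(item);
        }
    }
    // hits: {2 -> 9, 0 -> 7}; misses: {5}; index 6 is never scanned
    return 0;
}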
mDecodeStep << " Rnak: " << mRank << " reqID: " - << reqID << " layerID: " << layerID << " topk len: " << topkLen << " topk: "; - for (int j = 0; j < topkLen; j++) { - int64_t item = 0; - if (mUseTopkIdxs.scalar_type() == torch::kInt32) { - item = mUseTopkIdxs[layerID][topkIndex][j].item(); - } else { - item = mUseTopkIdxs[layerID][topkIndex][j].item(); - } - oss << item << " "; - if (mDocsTables[reqID][layerID].find(item) != mDocsTables[reqID][layerID].end()) { - int blockID = mDocsTables[reqID][layerID][item]; - hitBlocks.insert(blockID); - hitBlocksIdx.insert(std::make_pair(item, blockID)); - if (hitBlocks.size() == (topkLen - mExtraTopkLen)) { - break; - } - } else { - missIdxs.push_back(item); - } + uint32_t loadLen = oneFreeBlockTable.size(); + missIdxs.erase(missIdxs.begin() + loadLen, missIdxs.end()); + allNeedLoadBlock[reqID][layerID] = oneFreeBlockTable; + allMissIdxs[reqID][layerID] = missIdxs; + LoadKVToHBM(oneFreeBlockTable, missIdxs, layerID, reqID); +} + +void GSAPrefetchEngineC::RunOneBsPrefetch(std::string reqID, int topkLen, int bsIndex, + int topkIndex) +{ +#pragma omp parallel for num_threads(16) proc_bind(master) + for (int i = 0; i < mLayerNum; i++) { + mLoadSuccessBlocks[i][bsIndex].fill_(0); + int* freeBlockPtr = mFreeBlock[i][bsIndex].data_ptr(); + std::unordered_set hitBlocks; + std::map hitBlocksIdx; + std::vector missIdxs; + PrefetchReqInfo oneBsInfo; + oneBsInfo.topkLen = topkLen; + oneBsInfo.reqID = reqID; + oneBsInfo.topkIndex = topkIndex; + oneBsInfo.bsIndex = bsIndex; + oneBsInfo.layerID = i; + GetHitAndMissBlock(oneBsInfo, hitBlocks, hitBlocksIdx, missIdxs); + if (missIdxs.size() != 0 && hitBlocksIdx.size() < (topkLen - mExtraTopkLen)) { + RunPrefetchH2D(oneBsInfo, hitBlocks, hitBlocksIdx, missIdxs); } - oss << "------\n"; - mLogger.log(LogLevel::DEBUG, oss.str().c_str()); - oss.str(""); - if ((hitBlocks.size() + missIdxs.size()) != (uint32_t)topkLen && hitBlocks.size() != (topkLen - mExtraTopkLen)) { - mLogger.log(LogLevel::ERROR, - "|KVCache Prefetch| Decode step: %u, Rank: %d, reqID: %s, layer: %d, hit size: %lu, miss size: %lu , topkLen: %d, not equal error\n", - mDecodeStep, mRank, reqID, layerID, hitBlocks.size(), missIdxs.size(), topkLen); - PrintMap(reqID, layerID); + int successIndex = 0; + for (auto it = hitBlocksIdx.begin(); it != hitBlocksIdx.end(); it++) { + mLoadSuccessBlocks[i][bsIndex][successIndex] = it->second; + successIndex += 1; } - - } - - void GSAPrefetchEngineC::RunPrefetchH2D(PrefetchReqInfo oneBsInfo, - std::unordered_set &hitBlocks, - std::map &hitBlocksIdx, - std::vector &missIdxs) - { - int layerID = oneBsInfo.layerID; - std::string reqID = oneBsInfo.reqID; - uint32_t topkLen = oneBsInfo.topkLen; - int bsIndex = oneBsInfo.bsIndex; - - int oneFreeBlockLen = mFreeBlockLen[layerID][bsIndex].item(); - int *freeBlockPtr = mFreeBlock[layerID][bsIndex].data_ptr(); - std::vector oneFreeBlockTable; - - uint32_t index = 0; int oneFreeBlockIndex = 0; - while(oneFreeBlockIndex < oneFreeBlockLen && index < missIdxs.size() && hitBlocks.size() < (topkLen - mExtraTopkLen)) { - int oneFreeBlockID = freeBlockPtr[oneFreeBlockIndex]; - if (hitBlocks.find(oneFreeBlockID) != hitBlocks.end()) { - oneFreeBlockIndex += 1; + for (auto it = mDocsTables[reqID][i].begin(); it != mDocsTables[reqID][i].end(); it++) { + if (it->first >= mPromptLen[reqID]) { break; } + if (hitBlocksIdx.find(it->first) != hitBlocksIdx.end()) { continue; } else { - oneFreeBlockTable.push_back(oneFreeBlockID); - hitBlocks.insert(oneFreeBlockID); - 
hitBlocksIdx.insert(std::make_pair(missIdxs[index], oneFreeBlockID)); - index += 1; + freeBlockPtr[oneFreeBlockIndex] = it->second; oneFreeBlockIndex += 1; } } - uint32_t loadLen = oneFreeBlockTable.size(); - missIdxs.erase(missIdxs.begin() + loadLen, missIdxs.end()); - allNeedLoadBlock[reqID][layerID] = oneFreeBlockTable; - allMissIdxs[reqID][layerID] = missIdxs; - LoadKVToHBM(oneFreeBlockTable, missIdxs, layerID, reqID); + mFreeBlockLen[i][bsIndex] = oneFreeBlockIndex; + mSuccessTableLen[i][bsIndex] = (int)(hitBlocks.size()); } +} - void GSAPrefetchEngineC::RunOneBsPrefetch(std::string reqID, - int topkLen, int bsIndex, int topkIndex) - { -#pragma omp parallel for num_threads(16) proc_bind(master) - for (int i = 0; i < mLayerNum; i++) { - mLoadSuccessBlocks[i][bsIndex].fill_(0); - int *freeBlockPtr = mFreeBlock[i][bsIndex].data_ptr(); - std::unordered_set hitBlocks; - std::map hitBlocksIdx; - std::vector missIdxs; - PrefetchReqInfo oneBsInfo; - oneBsInfo.topkLen = topkLen; - oneBsInfo.reqID = reqID; - oneBsInfo.topkIndex = topkIndex; - oneBsInfo.bsIndex = bsIndex; - oneBsInfo.layerID = i; - GetHitAndMissBlock(oneBsInfo, hitBlocks, hitBlocksIdx, missIdxs); - if (missIdxs.size() != 0 && hitBlocksIdx.size() < (topkLen - mExtraTopkLen)) { - RunPrefetchH2D(oneBsInfo, hitBlocks, hitBlocksIdx, missIdxs); +void GSAPrefetchEngineC::LoadKVToHBM(std::vector loadNPUBlockIDs, std::vector missIdxs, + int layerID, std::string reqID) +{ + for (size_t i = 0; i < loadNPUBlockIDs.size(); i++) { + if (!mIsPythonLoad) { + if (mDelSeqIds.find(reqID) != mDelSeqIds.end()) { + mLogger.log(LogLevel::INFO, + "Decode step: %u, Rank: %d, reqID: %s, layer: %d, stop prefetch\n", + mDecodeStep, mRank, reqID.c_str(), layerID); + return; } - int successIndex = 0; - for (auto it = hitBlocksIdx.begin(); it != hitBlocksIdx.end(); it++) { - mLoadSuccessBlocks[i][bsIndex][successIndex] = it->second; - successIndex += 1; + while (mStopPrefetch) { std::this_thread::sleep_for(std::chrono::microseconds(2)); } + UC::Task task{UC::Task::Type::LOAD, UC::Task::Location::DEVICE, "NFS::S2D"}; + std::string blockId = mAllBlcoksHash[reqID][missIdxs[i]]; + size_t kOffset = GetOffsetNew(layerID, false); + size_t vOffset = GetOffsetNew(layerID, true); + if (!mUseMla) { + task.Append(blockId, kOffset, + reinterpret_cast( + mKvCaches[layerID][0][loadNPUBlockIDs[i]].data_ptr()), + mKVSzieBytes); + task.Append(blockId, vOffset, + reinterpret_cast( + mKvCaches[layerID][1][loadNPUBlockIDs[i]].data_ptr()), + mKVSzieBytes); + } else { + task.Append( + blockId, kOffset, + reinterpret_cast(mKvCaches[layerID][loadNPUBlockIDs[i]].data_ptr()), + mKVSzieBytes); } - int oneFreeBlockIndex = 0; - for (auto it = mDocsTables[reqID][i].begin(); it != mDocsTables[reqID][i].end(); it++) { - if (it->first >= mPromptLen[reqID]) { - break; - } - if (hitBlocksIdx.find(it->first) != hitBlocksIdx.end()) { - continue; - } else { - freeBlockPtr[oneFreeBlockIndex] = it->second; - oneFreeBlockIndex += 1; - } + size_t taskID = mStore->Submit(std::move(task)); + auto ret = mStore->Wait(taskID); + if (ret != 0) { + mLogger.log(LogLevel::ERROR, + "Decode step: %u, Rank: %d, reqID: %s, layer: %d, blockID: %lu, miss " + "idx: %u, load blockid: %u load k error\n", + mDecodeStep, mRank, reqID.c_str(), layerID, blockId, missIdxs[i], + loadNPUBlockIDs[i]); + return; } - mFreeBlockLen[i][bsIndex] = oneFreeBlockIndex; - mSuccessTableLen[i][bsIndex] = (int)(hitBlocks.size()); } - } - - void GSAPrefetchEngineC::LoadKVToHBM(std::vector loadNPUBlockIDs, - std::vector missIdxs, int 
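LoadKVToHBM issues one store round trip per claimed block: build a UC::Task targeting the device ("NFS::S2D"), Append the K (and, for non-MLA caches, V) destination at the offsets from GetOffsetNew, then Submit and Wait; the loop also backs off while the model is running (mStopPrefetch) and bails out if the request was queued for deletion. The stub below only mimics that call shape; MockTask and MockStore are editorial stand-ins, not the real CCStore API:

#include <cstddef>
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

struct MockTask {
    std::vector<std::string> blockIds; // one entry per Append()
    void Append(const std::string& id, std::size_t offset, std::uintptr_t dst,
                std::size_t len)
    {
        (void)offset; (void)dst; (void)len;
        blockIds.push_back(id);
    }
};

struct MockStore {
    std::size_t Submit(MockTask&&) { return 1; } // hands back a task handle
    int Wait(std::size_t) { return 0; }          // 0 means the copy landed
};

int main()
{
    MockStore store;
    MockTask task;
    // K at its layer offset; a second Append would add V for non-MLA caches.
    task.Append("block-hash", /*offset=*/0, /*dst=*/0, /*len=*/1310720);
    std::size_t taskID = store.Submit(std::move(task));
    return store.Wait(taskID); // non-zero aborts the prefetch for this request
}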
layerID, std::string reqID) - { - for (size_t i = 0; i < loadNPUBlockIDs.size(); i++) { - if (!mIsPythonLoad) { - if (mDelSeqIds.find(reqID) != mDelSeqIds.end()) { - mLogger.log(LogLevel::INFO, - "Decode step: %u, Rank: %d, reqID: %s, layer: %d, stop prefetch\n", - mDecodeStep, mRank, reqID.c_str(), layerID); - return; - } - while (mStopPrefetch) { - std::this_thread::sleep_for(std::chrono::microseconds(2)); - } - UC::Task task{UC::Task::Type::LOAD, UC::Task::Location::DEVICE, "NFS::S2D"}; - std::string blockId = mAllBlcoksHash[reqID][missIdxs[i]]; - size_t kOffset = GetOffsetNew(layerID, false); - size_t vOffset = GetOffsetNew(layerID, true); - if (!mUseMla) { - task.Append(blockId, kOffset, - reinterpret_cast(mKvCaches[layerID][0][loadNPUBlockIDs[i]].data_ptr()), - mKVSzieBytes); - task.Append(blockId, vOffset, - reinterpret_cast(mKvCaches[layerID][1][loadNPUBlockIDs[i]].data_ptr()), - mKVSzieBytes); - } else { - task.Append(blockId, kOffset, - reinterpret_cast(mKvCaches[layerID][loadNPUBlockIDs[i]].data_ptr()), - mKVSzieBytes); - } - size_t taskID = mStore->Submit(std::move(task)); - auto ret = mStore->Wait(taskID); - if (ret != 0) { - mLogger.log(LogLevel::ERROR, - "Decode step: %u, Rank: %d, reqID: %s, layer: %d, blockID: %lu, miss idx: %u, load blockid: %u load k error\n", - mDecodeStep, mRank, reqID.c_str(), layerID, blockId, missIdxs[i], loadNPUBlockIDs[i]); - return; - } - } - int oriIdx = mBlocksMap[reqID][layerID][loadNPUBlockIDs[i]]; - mBlocksMap[reqID][layerID][loadNPUBlockIDs[i]] = missIdxs[i]; - mDocsTables[reqID][layerID].erase(oriIdx); - mDocsTables[reqID][layerID][missIdxs[i]] = loadNPUBlockIDs[i]; - } + int oriIdx = mBlocksMap[reqID][layerID][loadNPUBlockIDs[i]]; + mBlocksMap[reqID][layerID][loadNPUBlockIDs[i]] = missIdxs[i]; + mDocsTables[reqID][layerID].erase(oriIdx); + mDocsTables[reqID][layerID][missIdxs[i]] = loadNPUBlockIDs[i]; } +} - void GSAPrefetchEngineC::RunAsyncPrefetchBs(std::vector &reqIDsInput, - std::vector &topkLensInput, - std::vector &bsIndexInput, - std::vector &kvCaches, - void *storePtr) - { - if (mKVSzieBytes == 0) { - mTensorElemSize = kvCaches[0].element_size(); - if (mUseMla) { - mKVSzieBytes = kvCaches[0].element_size() * kvCaches[0][0].numel(); - } else { - mKVSzieBytes = kvCaches[0].element_size() * kvCaches[0][0][0].numel(); - } - if (storePtr == nullptr) { - mLogger.log(LogLevel::ERROR, - "Decode step: %u, |KVCache Prefetch| storePtr is nullptr error\n", - mDecodeStep); - std::abort(); - } - mStore = static_cast *>(storePtr); - mLogger.log(LogLevel::INFO, - "Decode step: %u, |KVCache Prefetch| start mKVSzieBytes: %u, mTensorElemSize %u, store %p\n", - mDecodeStep, mKVSzieBytes, mTensorElemSize, mStore); +void GSAPrefetchEngineC::RunAsyncPrefetchBs(std::vector& reqIDsInput, + std::vector& topkLensInput, + std::vector& bsIndexInput, + std::vector& kvCaches, void* storePtr) +{ + if (mKVSzieBytes == 0) { + mTensorElemSize = kvCaches[0].element_size(); + if (mUseMla) { + mKVSzieBytes = kvCaches[0].element_size() * kvCaches[0][0].numel(); + } else { + mKVSzieBytes = kvCaches[0].element_size() * kvCaches[0][0][0].numel(); } - mKvCaches = kvCaches; - mLogger.log(LogLevel::INFO, - "Decode step: %u, |KVCache Prefetch| start async pretch batch size: %lu\n", - mDecodeStep, reqIDsInput.size()); - runBsLen = reqIDsInput.size(); - if (runBsLen > mMaxBs) { + if (storePtr == nullptr) { mLogger.log(LogLevel::ERROR, - "Decode step: %u, |KVCache Prefetch| runBsLen %u, maxBs: %d\n", - mDecodeStep, runBsLen, mMaxBs); + "Decode step: %u, |KVCache Prefetch| 
storePtr is nullptr error\n", + mDecodeStep); std::abort(); } - mReqIdList.clear(); - mReqIdList.assign(reqIDsInput.begin(), reqIDsInput.end()); - memcpy(mTopkLenList, topkLensInput.data(), sizeof(int) * runBsLen); - memcpy(mBsIndexList, bsIndexInput.data(), sizeof(int) * runBsLen); - mMutex.lock(); - mIsPrefetchDone = false; - mMutex.unlock(); - if (mIsPythonLoad) { - MutliBSThreadFun(this); - } else { - mThreadPool->enqueue(MutliBSThreadFun, this); - } + mStore = static_cast*>(storePtr); + mLogger.log(LogLevel::INFO, + "Decode step: %u, |KVCache Prefetch| start mKVSzieBytes: %u, mTensorElemSize " + "%u, store %p\n", + mDecodeStep, mKVSzieBytes, mTensorElemSize, mStore); } - - void GSAPrefetchEngineC::SetBlockTableInfo(torch::Tensor &blockTables, torch::Tensor &blockLengths, - torch::Tensor &inputTopkBuf, int step) - { - mLoadSuccessBlocks = blockTables; - mSuccessTableLen = blockLengths; - mUseTopkIdxs = inputTopkBuf.clone(); - mDecodeStep = step; + mKvCaches = kvCaches; + mLogger.log(LogLevel::INFO, + "Decode step: %u, |KVCache Prefetch| start async pretch batch size: %lu\n", + mDecodeStep, reqIDsInput.size()); + runBsLen = reqIDsInput.size(); + if (runBsLen > mMaxBs) { + mLogger.log(LogLevel::ERROR, "Decode step: %u, |KVCache Prefetch| runBsLen %u, maxBs: %d\n", + mDecodeStep, runBsLen, mMaxBs); + std::abort(); + } + mReqIdList.clear(); + mReqIdList.assign(reqIDsInput.begin(), reqIDsInput.end()); + memcpy(mTopkLenList, topkLensInput.data(), sizeof(int) * runBsLen); + memcpy(mBsIndexList, bsIndexInput.data(), sizeof(int) * runBsLen); + mMutex.lock(); + mIsPrefetchDone = false; + mMutex.unlock(); + if (mIsPythonLoad) { + MutliBSThreadFun(this); + } else { + mThreadPool->enqueue(MutliBSThreadFun, this); } +} +void GSAPrefetchEngineC::SetBlockTableInfo(torch::Tensor& blockTables, torch::Tensor& blockLengths, + torch::Tensor& inputTopkBuf, int step) +{ + mLoadSuccessBlocks = blockTables; + mSuccessTableLen = blockLengths; + mUseTopkIdxs = inputTopkBuf.clone(); + mDecodeStep = step; +} - int GSAPrefetchEngineC::CallPrefetchProcessFun() - { - auto start = std::chrono::high_resolution_clock::now(); - allNeedLoadBlock.clear(); - allMissIdxs.clear(); - for (size_t i = 0; i < runBsLen; i++) { - if (mDocsTables.find(mReqIdList[i]) == mDocsTables.end() || mTopkLenList[i] <= 0) { - mLogger.log(LogLevel::ERROR, - "Decode step: %u, |KVCache Prefetch| topk len is zero: %d\n", - mDecodeStep, mTopkLenList[i]); - continue; - } - allMissIdxs.insert({mReqIdList[i], std::vector>(mLayerNum)}); - allNeedLoadBlock.insert({mReqIdList[i], std::vector>(mLayerNum)}); - RunOneBsPrefetch(mReqIdList[i], mTopkLenList[i], mBsIndexList[i], i); +int GSAPrefetchEngineC::CallPrefetchProcessFun() +{ + auto start = std::chrono::high_resolution_clock::now(); + allNeedLoadBlock.clear(); + allMissIdxs.clear(); + for (size_t i = 0; i < runBsLen; i++) { + if (mDocsTables.find(mReqIdList[i]) == mDocsTables.end() || mTopkLenList[i] <= 0) { + mLogger.log(LogLevel::ERROR, + "Decode step: %u, |KVCache Prefetch| topk len is zero: %d\n", mDecodeStep, + mTopkLenList[i]); + continue; } - auto end = std::chrono::high_resolution_clock::now(); - auto duration = std::chrono::duration_cast(end - start); - mLogger.log(LogLevel::INFO, - "Decode step: %u, |KVCache Prefetch| Finish async pretch cost: %lu\n", - mDecodeStep, duration.count()); - return 0; + allMissIdxs.insert({mReqIdList[i], std::vector>(mLayerNum)}); + allNeedLoadBlock.insert({mReqIdList[i], std::vector>(mLayerNum)}); + RunOneBsPrefetch(mReqIdList[i], mTopkLenList[i], 
mBsIndexList[i], i); } + auto end = std::chrono::high_resolution_clock::now(); + auto duration = std::chrono::duration_cast(end - start); + mLogger.log(LogLevel::INFO, + "Decode step: %u, |KVCache Prefetch| Finish async pretch cost: %lu\n", mDecodeStep, + duration.count()); + return 0; +} - bool GSAPrefetchEngineC::GetPrefetchStatus() - { - return mIsPrefetchDone; - } +bool GSAPrefetchEngineC::GetPrefetchStatus() { return mIsPrefetchDone; } - void GSAPrefetchEngineC::SetPrefetchStatus(bool flag) - { - mMutex.lock(); - mIsPrefetchDone = flag; - mMutex.unlock(); - } +void GSAPrefetchEngineC::SetPrefetchStatus(bool flag) +{ + mMutex.lock(); + mIsPrefetchDone = flag; + mMutex.unlock(); +} - void GSAPrefetchEngineC::SetModelRunningStatus(bool flag) - { - mStopPrefetch = flag; - } +void GSAPrefetchEngineC::SetModelRunningStatus(bool flag) { mStopPrefetch = flag; } - std::map>> GSAPrefetchEngineC::ObtainLoadBlocks() - { - return allNeedLoadBlock; - } +std::map>> GSAPrefetchEngineC::ObtainLoadBlocks() +{ + return allNeedLoadBlock; +} - std::map>> GSAPrefetchEngineC::ObtainMissIdxs() - { - return allMissIdxs; - } +std::map>> GSAPrefetchEngineC::ObtainMissIdxs() +{ + return allMissIdxs; +} - std::map>> GSAPrefetchEngineC::ObtainBlocksMap() - { - return mBlocksMap; - } +std::map>> GSAPrefetchEngineC::ObtainBlocksMap() +{ + return mBlocksMap; +} - std::map>> GSAPrefetchEngineC::ObtainDocsMap() - { - return mDocsTables; - } -} // namespace uc +std::map>> GSAPrefetchEngineC::ObtainDocsMap() +{ + return mDocsTables; +} +} // namespace ucmprefetch diff --git a/ucm/sparse/gsa/prefetch/src/pybinds.cpp b/ucm/sparse/gsa/prefetch/src/pybinds.cpp index decd3895..25a1f5d5 100644 --- a/ucm/sparse/gsa/prefetch/src/pybinds.cpp +++ b/ucm/sparse/gsa/prefetch/src/pybinds.cpp @@ -1,30 +1,29 @@ #pragma GCC diagnostic push -#include -#include #include +#include #include +#include #pragma GCC diagnostic pop #include "kvcache_pre.h" -namespace ucmprefetch{ - PYBIND11_MODULE(gsa_prefetch, m) - { - pybind11::class_(m, "GSAPrefetchEngineC") - .def(pybind11::init &, - bool, bool, int, int, int, bool>()) - .def("set_blocks_map", &ucmprefetch::GSAPrefetchEngineC::SetBlocksMap) - .def("set_blocks_map_multilayer", &ucmprefetch::GSAPrefetchEngineC::SetBlocksMapMultiLayer) - .def("add_blocks_map", &ucmprefetch::GSAPrefetchEngineC::AddBlocksMap) - .def("del_blocks_map", &ucmprefetch::GSAPrefetchEngineC::DelBlocksMap) - .def("run_async_prefetch_bs", &ucmprefetch::GSAPrefetchEngineC::RunAsyncPrefetchBs) - .def("set_blocks_table_info", &ucmprefetch::GSAPrefetchEngineC::SetBlockTableInfo) - .def("get_prefetch_status", &ucmprefetch::GSAPrefetchEngineC::GetPrefetchStatus) - .def("set_prefetch_status", &ucmprefetch::GSAPrefetchEngineC::SetPrefetchStatus) - .def("set_modelrunning_status", &ucmprefetch::GSAPrefetchEngineC::SetModelRunningStatus) - .def("obtain_load_blocks", &ucmprefetch::GSAPrefetchEngineC::ObtainLoadBlocks) - .def("obtain_miss_idxs", &ucmprefetch::GSAPrefetchEngineC::ObtainMissIdxs) - .def("obtain_docs_map", &ucmprefetch::GSAPrefetchEngineC::ObtainDocsMap) - .def("obtain_blocks_map", &ucmprefetch::GSAPrefetchEngineC::ObtainBlocksMap); - } +namespace ucmprefetch { +PYBIND11_MODULE(gsa_prefetch, m) +{ + pybind11::class_(m, "GSAPrefetchEngineC") + .def(pybind11::init&, bool, bool, int, int, int, bool>()) + .def("set_blocks_map", &ucmprefetch::GSAPrefetchEngineC::SetBlocksMap) + .def("set_blocks_map_multilayer", &ucmprefetch::GSAPrefetchEngineC::SetBlocksMapMultiLayer) + .def("add_blocks_map", 
&ucmprefetch::GSAPrefetchEngineC::AddBlocksMap) + .def("del_blocks_map", &ucmprefetch::GSAPrefetchEngineC::DelBlocksMap) + .def("run_async_prefetch_bs", &ucmprefetch::GSAPrefetchEngineC::RunAsyncPrefetchBs) + .def("set_blocks_table_info", &ucmprefetch::GSAPrefetchEngineC::SetBlockTableInfo) + .def("get_prefetch_status", &ucmprefetch::GSAPrefetchEngineC::GetPrefetchStatus) + .def("set_prefetch_status", &ucmprefetch::GSAPrefetchEngineC::SetPrefetchStatus) + .def("set_modelrunning_status", &ucmprefetch::GSAPrefetchEngineC::SetModelRunningStatus) + .def("obtain_load_blocks", &ucmprefetch::GSAPrefetchEngineC::ObtainLoadBlocks) + .def("obtain_miss_idxs", &ucmprefetch::GSAPrefetchEngineC::ObtainMissIdxs) + .def("obtain_docs_map", &ucmprefetch::GSAPrefetchEngineC::ObtainDocsMap) + .def("obtain_blocks_map", &ucmprefetch::GSAPrefetchEngineC::ObtainBlocksMap); } +} // namespace ucmprefetch From 2399db86ced4fc9cf20b22f8654929a8b7369fe1 Mon Sep 17 00:00:00 2001 From: zbb200819 <1130072360@qq.com> Date: Mon, 1 Dec 2025 06:03:58 -0800 Subject: [PATCH 8/8] cleancode --- ucm/sparse/gsa/prefetch/include/kvcache_pre.h | 2 +- ucm/sparse/gsa/prefetch/src/kvcache_pre.cpp | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/ucm/sparse/gsa/prefetch/include/kvcache_pre.h b/ucm/sparse/gsa/prefetch/include/kvcache_pre.h index 0a8e0b02..1ee9952f 100644 --- a/ucm/sparse/gsa/prefetch/include/kvcache_pre.h +++ b/ucm/sparse/gsa/prefetch/include/kvcache_pre.h @@ -47,7 +47,7 @@ class ThreadPool { ~ThreadPool(); template - auto enqueue(F&& f, Args&&... args) -> std::future::type>; + auto Enqueue(F&& f, Args&&... args) -> std::future::type>; size_t GetActiveThreads() const; diff --git a/ucm/sparse/gsa/prefetch/src/kvcache_pre.cpp b/ucm/sparse/gsa/prefetch/src/kvcache_pre.cpp index d7670afa..8c0bde87 100644 --- a/ucm/sparse/gsa/prefetch/src/kvcache_pre.cpp +++ b/ucm/sparse/gsa/prefetch/src/kvcache_pre.cpp @@ -43,7 +43,7 @@ ThreadPool::~ThreadPool() } template -auto ThreadPool::enqueue(F&& f, Args&&... args) +auto ThreadPool::Enqueue(F&& f, Args&&... args) -> std::future::type> { using return_type = typename std::result_of::type; @@ -509,7 +509,7 @@ void GSAPrefetchEngineC::RunAsyncPrefetchBs(std::vector& reqIDsInpu if (mIsPythonLoad) { MutliBSThreadFun(this); } else { - mThreadPool->enqueue(MutliBSThreadFun, this); + mThreadPool->Enqueue(MutliBSThreadFun, this); } }