From 343e8098732a1ccd6f4f64bfba7bd35e43087762 Mon Sep 17 00:00:00 2001 From: wenxinwang Date: Mon, 1 Dec 2025 12:51:27 +0800 Subject: [PATCH 1/2] sparse to adapt new connector --- examples/offline_inference_esa.py | 21 ++-- examples/offline_inference_kvcomp.py | 18 ++-- examples/offline_inference_kvstar.py | 18 ++-- ucm/sparse/esa/esa.py | 149 ++++++++++++++++----------- ucm/sparse/kvcomp/kvcomp.py | 5 +- ucm/sparse/kvstar/multistep.py | 136 +++++++++++++++--------- ucm/sparse/kvstar/utils.py | 26 +++++ 7 files changed, 245 insertions(+), 128 deletions(-) diff --git a/examples/offline_inference_esa.py b/examples/offline_inference_esa.py index c420e9b9..88dc0198 100644 --- a/examples/offline_inference_esa.py +++ b/examples/offline_inference_esa.py @@ -66,12 +66,19 @@ def build_llm_with_uc(module_path: str, name: str, model: str): kv_connector=name, kv_connector_module_path=module_path, kv_role="kv_both", + # kv_connector_extra_config={ + # "UCM_CONFIG_FILE": "/home/externals/wangwenxin21/va_new/unified-cache-management/examples/ucm_config_example.yaml" + # }, kv_connector_extra_config={ - "ucm_connector_name": "UcmNfsStore", - "ucm_connector_config": { - "storage_backends": data_dir, - "kv_block_size": 33554432, - }, + "ucm_connectors": [ + { + "ucm_connector_name": "UcmNfsStore", + "ucm_connector_config": { + "storage_backends": data_dir, + "use_direct": False, + }, + } + ], "ucm_sparse_config": { "ESA": { "init_window_sz": 1, @@ -125,8 +132,8 @@ def print_output( def main(): - module_path = "ucm.integration.vllm.uc_connector" - name = "UnifiedCacheConnectorV1" + module_path = "ucm.integration.vllm.ucm_connector" + name = "UCMConnector" setup_environment_variables() def get_prompt(prompt): diff --git a/examples/offline_inference_kvcomp.py b/examples/offline_inference_kvcomp.py index 6aa2ed31..f1cd6c19 100644 --- a/examples/offline_inference_kvcomp.py +++ b/examples/offline_inference_kvcomp.py @@ -67,11 +67,15 @@ def build_llm_with_uc(module_path: str, name: str, model: str): kv_connector_module_path=module_path, kv_role="kv_both", kv_connector_extra_config={ - "ucm_connector_name": "UcmNfsStore", - "ucm_connector_config": { - "storage_backends": data_dir, - "kv_block_size": 33554432, - }, + "ucm_connectors": [ + { + "ucm_connector_name": "UcmNfsStore", + "ucm_connector_config": { + "storage_backends": data_dir, + "use_direct": False, + }, + } + ], "ucm_sparse_config": { "KvComp": { "init_window_sz": 1, @@ -123,8 +127,8 @@ def print_output( def main(): - module_path = "ucm.integration.vllm.uc_connector" - name = "UnifiedCacheConnectorV1" + module_path = "ucm.integration.vllm.ucm_connector" + name = "UCMConnector" setup_environment_variables() def get_prompt(prompt): diff --git a/examples/offline_inference_kvstar.py b/examples/offline_inference_kvstar.py index 70217542..c8dfd1ee 100644 --- a/examples/offline_inference_kvstar.py +++ b/examples/offline_inference_kvstar.py @@ -68,11 +68,15 @@ def build_llm_with_uc(module_path: str, name: str, model: str): kv_connector_module_path=module_path, kv_role="kv_both", kv_connector_extra_config={ - "ucm_connector_name": "UcmNfsStore", - "ucm_connector_config": { - "storage_backends": data_dir, - "kv_block_size": 33554432, - }, + "ucm_connectors": [ + { + "ucm_connector_name": "UcmNfsStore", + "ucm_connector_config": { + "storage_backends": data_dir, + "use_direct": False, + }, + } + ], "ucm_sparse_config": { "KVStarMultiStep": { "init_window_sz": 1, @@ -123,8 +127,8 @@ def print_output( def main(): - module_path = "ucm.integration.vllm.uc_connector" - 
name = "UnifiedCacheConnectorV1" + module_path = "ucm.integration.vllm.ucm_connector" + name = "UCMConnector" setup_environment_variables() def get_prompt(prompt): diff --git a/ucm/sparse/esa/esa.py b/ucm/sparse/esa/esa.py index d8316cc6..fc6226cd 100644 --- a/ucm/sparse/esa/esa.py +++ b/ucm/sparse/esa/esa.py @@ -16,6 +16,7 @@ from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.request import Request, RequestStatus +from ucm.integration.vllm.ucm_connector import RequestHasher from ucm.sparse.base import ( INVALID_SLOT, UcmSparseBase, @@ -61,6 +62,7 @@ class ReqMeta: prompt_token_ids: list[int] output_token_ids: list[int] is_preempt: bool + ucm_block_hashes: list[str] @property def num_prompt_tokens(self) -> int: @@ -100,6 +102,7 @@ def add_request( prompt_token_ids: list[int], output_token_ids: list[int], is_preempt: bool, + ucm_block_hashes: list[str], ) -> None: meta = ReqMeta( @@ -112,6 +115,7 @@ def add_request( prompt_token_ids=prompt_token_ids, output_token_ids=output_token_ids, is_preempt=is_preempt, + ucm_block_hashes=ucm_block_hashes, ) self.requests.append(meta) @@ -140,18 +144,29 @@ def get_sparse_range(init_window_sz, local_window_sz, prompt_len, block_size): @cache -def md5(input) -> int: - input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL) - md5_bytes = hashlib.md5(input_bytes).digest() - return int.from_bytes(md5_bytes, byteorder="big") +def compute_parent_block_hash(model_name, world_size, dtype, seed_rank=0) -> int: + meta = f"{model_name}:{world_size}:{dtype}:{seed_rank}" + meta_bytes = meta.encode("utf-8") + h_seed = hashlib.md5(meta_bytes + b"UCM_HASH_SEED").digest() + return int.from_bytes(h_seed, byteorder="big") @cache -def block_hash_func(parent_block_hash, curr_block_token_ids): - if not parent_block_hash: - parent_block_hash = md5("UCMHASHSEED") - curr_block_token_ids_tuple = tuple(curr_block_token_ids) - return md5((parent_block_hash, curr_block_token_ids_tuple)) +def compute_layer_offset( + block_data_size: int, + layer_id: int, + is_v: bool, + is_mla: bool, +) -> int: + layer_data_size = block_data_size if is_mla else block_data_size * 2 + + k_offset = layer_data_size * layer_id + + if is_mla: + return k_offset + + v_offset = k_offset + block_data_size + return v_offset if is_v else k_offset def task_hash_func(block_ids, store_type, tensor_type): @@ -178,7 +193,6 @@ def diff_two_map(map1: dict, map2: dict): class ReqStatePerLayer: # handle single request per layer - def __init__( self, layer_name: str, @@ -223,49 +237,22 @@ def __init__( self.is_mla = self.vllm_config.model_config.is_deepseek_mla self.step = 0 - def set_block_hashes(self, token_ids): - if self.block_hashes is not None: - return - self.block_hashes = [] - parent_block_hash_value = None - num_total_blocks = math.ceil(len(token_ids) / self.block_size) - for start in range(0, len(token_ids), self.block_size): - end = start + self.block_size - block_idx = start // self.block_size - if block_idx >= num_total_blocks - self.esa_cfg["local_window_sz"]: - continue - block_token_ids = token_ids[start:end] - if len(block_token_ids) < self.block_size: - break - curr_block_token_ids_tuple = tuple(block_token_ids) - block_hash = block_hash_func( - parent_block_hash_value, curr_block_token_ids_tuple - ) - if block_idx >= self.esa_cfg["init_window_sz"]: - self.block_hashes.append(str(block_hash)) - parent_block_hash_value = block_hash - def update_meta(self, req_meta: ReqMeta): self.req_meta = req_meta def launch_transfer_task(self, transfer_type, block_hashes, vllm_block_ids): 
fn = getattr(self.store_instance, transfer_type) length = len(block_hashes) - block_shape = (self.block_size, self.num_key_heads, self.head_size) precision = self.vllm_config.model_config.dtype.itemsize + block_data_size = self.k_cache[0].numel() * precision - block_shape = tuple(block_shape) - offsets_k = [ - get_offset( - block_shape, - self.rank, - self.tp_size, - precision, - self.layer_id, - is_v=False, - is_mla=self.is_mla, - ) - ] * length + offset_k = compute_layer_offset( + block_data_size, + self.layer_id, + is_v=False, + is_mla=self.is_mla, + ) + offsets_k = [offset_k] * length key_src_tensors = [self.k_cache[id_] for id_ in vllm_block_ids] task_k = fn(block_hashes, offsets_k, key_src_tensors) @@ -273,17 +260,13 @@ def launch_transfer_task(self, transfer_type, block_hashes, vllm_block_ids): self.tasks[task_k_hash] = task_k if not self.is_mla: - offsets_v = [ - get_offset( - block_shape, - self.rank, - self.tp_size, - precision, - self.layer_id, - is_v=True, - is_mla=self.is_mla, - ) - ] * length + offset_v = compute_layer_offset( + block_data_size, + self.layer_id, + is_v=True, + is_mla=self.is_mla, + ) + offsets_v = [offset_v] * length value_src_tensors = [self.v_cache[id_] for id_ in vllm_block_ids] task_v = fn(block_hashes, offsets_v, value_src_tensors) task_v_hash = task_hash_func(block_hashes, transfer_type, "value") @@ -303,7 +286,7 @@ def maybe_register_static_data(self, forward_context: ForwardContext): else: self.k_cache = kv_cache[0] self.v_cache = kv_cache[1] - self.set_block_hashes(self.req_meta.prompt_token_ids) + self.block_hashes = self.req_meta.ucm_block_hashes self.init_static_flag = True def wait_transfer_task_done(self): @@ -470,7 +453,7 @@ def __init__(self, vllm_config: VllmConfig, role: UcmSparseRole): self.rank = vllm_config.parallel_config.rank self.tp_size = vllm_config.parallel_config.tensor_parallel_size if role == UcmSparseRole.WORKER: - self.connector = get_kv_transfer_group().connector + self.connector = get_kv_transfer_group().connector.store else: self.connector = None self.esa_cfg = vllm_config.kv_transfer_config.kv_connector_extra_config[ @@ -483,6 +466,9 @@ def __init__(self, vllm_config: VllmConfig, role: UcmSparseRole): self._sparse_metadata_prefill: ESASparseMetaData = ESASparseMetaData() self._sparse_metadata_decode: ESASparseMetaData = ESASparseMetaData() self._sparse_metadata: ESASparseMetaData = ESASparseMetaData() + self.request_hasher = RequestHasher(vllm_config, 0) + self.block_size = vllm_config.cache_config.block_size + self.block_hashes: dict[int, dict[int, list[str]]] = {} global data if data is None: @@ -601,7 +587,6 @@ def attention_finished( forward_context: ForwardContext, phase: Optional[str] = None, ) -> None: - if not self.is_mla: for req_meta in self._sparse_metadata.requests: self.update_req_state_attention_end( @@ -643,6 +628,47 @@ def is_sparsed_request(self, req): >= self._vllm_config.cache_config.block_size * self.esa_cfg["min_blocks"] ) + def set_block_hashes(self, req_id, token_ids): + if req_id not in self.block_hashes: + self.block_hashes[req_id] = {} + + if self.rank in self.block_hashes[req_id]: + return + + self.block_hashes[req_id][self.rank] = [] + + parent_block_hash_value = compute_parent_block_hash( + self._vllm_config.model_config.model, + self._vllm_config.parallel_config.world_size, + self._vllm_config.model_config.dtype, + seed_rank=0, + ) + + num_total_blocks = math.ceil(len(token_ids) / self.block_size) + for start in range(0, len(token_ids), self.block_size): + end = start + self.block_size + 
block_idx = start // self.block_size + if block_idx >= num_total_blocks - self.esa_cfg["local_window_sz"]: + continue + block_token_ids = token_ids[start:end] + if len(block_token_ids) < self.block_size: + break + curr_block_token_ids_tuple = tuple(block_token_ids) + hash_value = self.request_hasher( + (parent_block_hash_value, curr_block_token_ids_tuple) + ) + if block_idx >= self.esa_cfg["init_window_sz"]: + self.block_hashes[req_id][self.rank].append(str(hash_value)) + + parent_block_hash_value = hash_value + + if self.rank != 0 and not self.is_mla: + self.newqrequest_hasher = RequestHasher(self._vllm_config, self.rank) + for i, ucm_block_id in enumerate(self.block_hashes[req_id][self.rank]): + self.block_hashes[req_id][self.rank][i] = str( + self.newqrequest_hasher(ucm_block_id) + ) + def build_sparse_meta( self, scheduler_output, requests, input_batch, attn_metadata ) -> UcmSparseMetadata: @@ -654,7 +680,6 @@ def build_sparse_meta( req_ids = list(getattr(input_batch, "req_ids", [])) decode_ids = [rid for rid in req_ids if num_sched.get(rid, 0) == 1] decode_set = set(decode_ids) - cached_reqs = scheduler_output.scheduled_cached_reqs preempt_reqs = set() if cached_reqs: @@ -670,6 +695,7 @@ def build_sparse_meta( req = requests[req_id] if not self.is_sparsed_request(req): continue + self.set_block_hashes(int(req_id), req.prompt_token_ids) if isinstance(attn_metadata, dict): attn_metadata = next(iter(attn_metadata.values())) @@ -684,6 +710,7 @@ def build_sparse_meta( req.prompt_token_ids, req.output_token_ids, req_id in preempt_reqs, + self.block_hashes[int(req_id)][self.rank], ) else: @@ -704,6 +731,7 @@ def build_sparse_meta( req.prompt_token_ids, req.output_token_ids, req_id in preempt_reqs, + self.block_hashes[int(req_id)][self.rank], ) else: @@ -720,6 +748,7 @@ def build_sparse_meta( req.prompt_token_ids, req.output_token_ids, req_id in preempt_reqs, + self.block_hashes[int(req_id)][self.rank], ) # self._sparse_metadata = sparse_meta diff --git a/ucm/sparse/kvcomp/kvcomp.py b/ucm/sparse/kvcomp/kvcomp.py index 8a1f6123..b2483c73 100644 --- a/ucm/sparse/kvcomp/kvcomp.py +++ b/ucm/sparse/kvcomp/kvcomp.py @@ -10,6 +10,7 @@ from vllm.forward_context import ForwardContext from vllm.v1.request import Request, RequestStatus +from ucm.integration.vllm.ucm_connector import RequestHasher from ucm.logger import init_logger from ucm.sparse.base import ( INVALID_SLOT, @@ -151,7 +152,7 @@ def __init__(self, vllm_config: VllmConfig, role: UcmSparseRole): self.rank = vllm_config.parallel_config.rank self.tp_size = vllm_config.parallel_config.tensor_parallel_size if role == UcmSparseRole.WORKER: - self.connector = get_kv_transfer_group().connector + self.connector = get_kv_transfer_group().connector.store else: self.connector = None self.total_num_hidden_layers = ( @@ -166,6 +167,8 @@ def __init__(self, vllm_config: VllmConfig, role: UcmSparseRole): ]["KvComp"] self.block_size = vllm_config.cache_config.block_size + self.block_hashes: dict[int, dict[int, list[str]]] = {} + self.request_hasher = RequestHasher(vllm_config, 0) self.num_kv_heads = vllm_config.model_config.get_num_kv_heads( vllm_config.parallel_config ) diff --git a/ucm/sparse/kvstar/multistep.py b/ucm/sparse/kvstar/multistep.py index 39e52fba..79b50e15 100644 --- a/ucm/sparse/kvstar/multistep.py +++ b/ucm/sparse/kvstar/multistep.py @@ -10,6 +10,7 @@ from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.request import Request +from ucm.integration.vllm.ucm_connector import RequestHasher from ucm.sparse.base import ( 
INVALID_SLOT, UcmSparseBase, @@ -17,7 +18,12 @@ UcmSparseRole, ) from ucm.sparse.kvstar.retrieve import kvstar_retrieve -from ucm.sparse.kvstar.utils import block_hash_func, get_bind_cpus_for_rank, get_offset +from ucm.sparse.kvstar.utils import ( + block_hash_func, + compute_layer_offset, + compute_parent_block_hash, + get_bind_cpus_for_rank, +) from ucm.store.ucmstore import Task, UcmKVStoreBase """ @@ -57,28 +63,28 @@ class ReqMeta: retrieval_stride: int = 8 block_hashes: list[str] = field(default_factory=list) - def set_block_hashes(self, token_ids): - block_hashes = [] - parent_block_hash_value = None - for start in range(0, len(token_ids), self.token_blk_size): - end = start + self.token_blk_size - block_token_ids = token_ids[start:end] - if len(block_token_ids) < self.token_blk_size: - break - curr_block_token_ids_tuple = tuple(block_token_ids) - block_hash = block_hash_func( - parent_block_hash_value, curr_block_token_ids_tuple - ) - block_hashes.append(str(block_hash)) - parent_block_hash_value = block_hash - return block_hashes - - @property - def req_block_hashes(self) -> list[str]: - if self.block_hashes: - return self.block_hashes - self.block_hashes = self.set_block_hashes(self.prompt_token_ids) - return self.block_hashes + # def set_block_hashes(self, token_ids): + # block_hashes = [] + # parent_block_hash_value = None + # for start in range(0, len(token_ids), self.token_blk_size): + # end = start + self.token_blk_size + # block_token_ids = token_ids[start:end] + # if len(block_token_ids) < self.token_blk_size: + # break + # curr_block_token_ids_tuple = tuple(block_token_ids) + # block_hash = block_hash_func( + # parent_block_hash_value, curr_block_token_ids_tuple + # ) + # block_hashes.append(str(block_hash)) + # parent_block_hash_value = block_hash + # return block_hashes + + # @property + # def req_block_hashes(self) -> list[str]: + # if self.block_hashes: + # return self.block_hashes + # self.block_hashes = self.set_block_hashes(self.prompt_token_ids) + # return self.block_hashes @property def step(self) -> int: @@ -153,6 +159,7 @@ def add_request( query_len: int, retrieval_stride: int, prompt_token_ids: list[int], + ucm_block_hashes: list[str], ) -> None: meta = ReqMeta( request_id=request_id, @@ -168,6 +175,7 @@ def add_request( query_start_loc=query_start_loc, query_len=query_len, retrieval_stride=retrieval_stride, + block_hashes=ucm_block_hashes, ) self.requests.append(meta) @@ -181,7 +189,6 @@ def __init__( rank: int, tp_size: int, store_instance: UcmKVStoreBase, - store_name: str, sparse_cfg, ): self.sparse_cfg = sparse_cfg @@ -193,7 +200,6 @@ def __init__( self.num_tokens = 0 # the number of all_tokens, prompt+output self.store_instance = store_instance - self.store_name = store_name self.req_meta = req_meta self.init_window: tuple[torch.Tensor, torch.Tensor] = None self.local_window: tuple[torch.Tensor, torch.Tensor] = None @@ -577,7 +583,7 @@ def maybe_register_kv_cache(self, forward_context: ForwardContext): self.v_cache = kv_cache[1] self.block_size = self.k_cache.shape[1] self.num_key_heads = self.k_cache.shape[2] - self.block_hashes = self.req_meta.req_block_hashes + self.block_hashes = self.req_meta.block_hashes self.head_size = self.k_cache.shape[3] @classmethod @@ -594,29 +600,22 @@ def update_meta(self, req_meta: ReqMeta, forward_context: ForwardContext): def launch_transfer_task(self, transfer_type, block_hashes, vllm_block_ids): fn = getattr(self.store_instance, transfer_type) length = len(block_hashes) - block_shape = (self.block_size, 
self.num_key_heads, self.head_size) precision = self.k_cache.storage().element_size() is_mla = False - block_shape = tuple(block_shape) + block_data_size = self.k_cache[0].numel() * precision offsets_k = [ - get_offset( - block_shape, - self.local_tp_rank, - self.total_tp_size, - precision, + compute_layer_offset( + block_data_size, self.layer_id, is_v=False, is_mla=is_mla, ) ] * length offsets_v = [ - get_offset( - block_shape, - self.local_tp_rank, - self.total_tp_size, - precision, + compute_layer_offset( + block_data_size, self.layer_id, is_v=True, is_mla=is_mla, @@ -652,6 +651,11 @@ def __init__(self, vllm_config: VllmConfig, role: UcmSparseRole): self.total_num_hidden_layers = ( vllm_config.model_config.hf_config.num_hidden_layers ) + self.block_size = vllm_config.cache_config.block_size + self.block_hashes: dict[int, dict[int, list[str]]] = {} + self.rank = vllm_config.parallel_config.rank + self.is_mla = vllm_config.model_config.is_deepseek_mla + self.request_hasher = RequestHasher(vllm_config, 0) if self.role == UcmSparseRole.WORKER: ratio = 0.75 bind_info_list, alloc_numa_ids = get_bind_cpus_for_rank( @@ -667,12 +671,12 @@ def __init__(self, vllm_config: VllmConfig, role: UcmSparseRole): localRankId=self.local_tp_rank, ) kvstar_retrieve.Setup(param) - self.connector_name = ( - self._vllm_config.kv_transfer_config.kv_connector_extra_config[ - "ucm_connector_name" - ] - ) - self.connector = get_kv_transfer_group().connector + # self.connector_name = ( + # self._vllm_config.kv_transfer_config.kv_connector_extra_config[ + # "ucm_connector_name" + # ] + # ) + self.connector = get_kv_transfer_group().connector.store else: self.connector = None @@ -701,7 +705,6 @@ def create_layerwise_req_state(self, req_meta, layer_name): self.local_tp_rank, self.total_tp_size, self.connector, - self.connector_name, self.kvstar_multistep_cfg, ) return self.req_states[req_meta.request_id][layer_id] @@ -726,6 +729,7 @@ def attention_begin( value: torch.Tensor, layer_name: str, forward_context: ForwardContext, + phase: Optional[str] = None, ) -> None: """ This is called at the beginning of "unified_attention". @@ -748,6 +752,7 @@ def attention_finished( attn_output: torch.Tensor, layer_name: str, forward_context: ForwardContext, + phase: Optional[str] = None, ) -> None: """ This is called at the end of "unified_attention". 
@@ -759,6 +764,44 @@ def attention_finished( query, key, value, attn_output, forward_context ) + def set_block_hashes(self, req_id, token_ids): + if req_id not in self.block_hashes: + self.block_hashes[req_id] = {} + + if self.rank in self.block_hashes[req_id]: + return + + self.block_hashes[req_id][self.rank] = [] + + parent_block_hash_value = compute_parent_block_hash( + self._vllm_config.model_config.model, + self._vllm_config.parallel_config.world_size, + self._vllm_config.model_config.dtype, + seed_rank=0, + ) + + for start in range(0, len(token_ids), self.block_size): + end = start + self.block_size + + block_token_ids = token_ids[start:end] + if len(block_token_ids) < self.block_size: + break + curr_block_token_ids_tuple = tuple(block_token_ids) + hash_value = self.request_hasher( + (parent_block_hash_value, curr_block_token_ids_tuple) + ) + + self.block_hashes[req_id][self.rank].append(str(hash_value)) + + parent_block_hash_value = hash_value + + if self.rank != 0 and not self.is_mla: + self.newqrequest_hasher = RequestHasher(self._vllm_config, self.rank) + for i, ucm_block_id in enumerate(self.block_hashes[req_id][self.rank]): + self.block_hashes[req_id][self.rank][i] = str( + self.newqrequest_hasher(ucm_block_id) + ) + def build_sparse_meta( self, scheduler_output, requests, input_batch, attn_metadata ) -> None: @@ -778,7 +821,7 @@ def build_sparse_meta( num_scheduled_tokens, ) in scheduler_output.num_scheduled_tokens.items(): req_state = requests[req_id] - + self.set_block_hashes(int(req_id), req_state.prompt_token_ids) q_start_loc = query_start_locs[input_batch.req_id_to_index[req_id]].item() q_len = ( query_start_locs[input_batch.req_id_to_index[req_id] + 1].item() @@ -800,6 +843,7 @@ def build_sparse_meta( q_len, self.kvstar_multistep_cfg["retrieval_stride"], req_state.prompt_token_ids, + self.block_hashes[int(req_id)][self.rank], ) self._sparse_metadata = sparse_meta diff --git a/ucm/sparse/kvstar/utils.py b/ucm/sparse/kvstar/utils.py index 35e1ddea..198b45fa 100644 --- a/ucm/sparse/kvstar/utils.py +++ b/ucm/sparse/kvstar/utils.py @@ -19,6 +19,24 @@ def get_offset(block_shape, rank, tp_size, precision, layer_id, is_v, is_mla) -> return v_offset if is_v else k_offset +@cache +def compute_layer_offset( + block_data_size: int, + layer_id: int, + is_v: bool, + is_mla: bool, +) -> int: + layer_data_size = block_data_size if is_mla else block_data_size * 2 + + k_offset = layer_data_size * layer_id + + if is_mla: + return k_offset + + v_offset = k_offset + block_data_size + return v_offset if is_v else k_offset + + @cache def md5(input) -> int: input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL) @@ -34,6 +52,14 @@ def block_hash_func(parent_block_hash, curr_block_token_ids): return md5((parent_block_hash, curr_block_token_ids_tuple)) +@cache +def compute_parent_block_hash(model_name, world_size, dtype, seed_rank=0) -> int: + meta = f"{model_name}:{world_size}:{dtype}:{seed_rank}" + meta_bytes = meta.encode("utf-8") + h_seed = hashlib.md5(meta_bytes + b"UCM_HASH_SEED").digest() + return int.from_bytes(h_seed, byteorder="big") + + def execute_command(cmd_list): with subprocess.Popen( cmd_list, shell=False, stdout=subprocess.PIPE, stderr=subprocess.PIPE From 627661390282a4c5b63594fcc0dc92ae4394386f Mon Sep 17 00:00:00 2001 From: wenxinwang Date: Mon, 1 Dec 2025 19:41:12 +0800 Subject: [PATCH 2/2] Adapt the YAML configuration --- examples/offline_inference_esa.py | 3 --- ucm/sparse/esa/esa.py | 18 +++++++++++------- ucm/sparse/factory.py | 7 ++++--- 
ucm/sparse/kvcomp/kvcomp.py | 18 +++++++++++------- ucm/sparse/kvstar/multistep.py | 8 +++++--- ucm/sparse/state.py | 12 +++++------- 6 files changed, 36 insertions(+), 30 deletions(-) diff --git a/examples/offline_inference_esa.py b/examples/offline_inference_esa.py index 88dc0198..852a8ca0 100644 --- a/examples/offline_inference_esa.py +++ b/examples/offline_inference_esa.py @@ -66,9 +66,6 @@ def build_llm_with_uc(module_path: str, name: str, model: str): kv_connector=name, kv_connector_module_path=module_path, kv_role="kv_both", - # kv_connector_extra_config={ - # "UCM_CONFIG_FILE": "/home/externals/wangwenxin21/va_new/unified-cache-management/examples/ucm_config_example.yaml" - # }, kv_connector_extra_config={ "ucm_connectors": [ { diff --git a/ucm/sparse/esa/esa.py b/ucm/sparse/esa/esa.py index fc6226cd..ac36e54c 100644 --- a/ucm/sparse/esa/esa.py +++ b/ucm/sparse/esa/esa.py @@ -4,7 +4,7 @@ from collections import defaultdict from dataclasses import dataclass from functools import cache -from typing import Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union import numpy as np import torch @@ -27,6 +27,7 @@ from ucm.sparse.esa.retrieval.retrieval_worker import RetrievalWorker from ucm.sparse.kvstar.utils import get_bind_cpus_for_rank from ucm.store.ucmstore import Task, UcmKVStoreBase +from ucm.utils import Config ReqType = Union[str, int] HashType = Union[str, int] @@ -202,6 +203,7 @@ def __init__( vllm_config: VllmConfig, retrieval_worker: Optional[RetrievalWorker] = None, repre_pool: Optional[ReprePool] = None, + esa_cfg: Optional[Dict[str, Any]] = None, ): self.layer_name = layer_name self.layer_id = int(layer_name.split(".")[2]) @@ -219,9 +221,7 @@ def __init__( self.rank = rank self.tp_size = tp_size self.tasks: Dict[str, Task] = {} - self.esa_cfg = vllm_config.kv_transfer_config.kv_connector_extra_config.get( - "ucm_sparse_config", {} - ).get("ESA", None) + self.esa_cfg = esa_cfg self.indexes: Optional[NDArray[np.int64]] = None self.block_hashes = None self.pre_topk_block_hashes: Dict[int, str] = {} @@ -456,9 +456,12 @@ def __init__(self, vllm_config: VllmConfig, role: UcmSparseRole): self.connector = get_kv_transfer_group().connector.store else: self.connector = None - self.esa_cfg = vllm_config.kv_transfer_config.kv_connector_extra_config[ - "ucm_sparse_config" - ]["ESA"] + self.esa_cfg = ( + Config(vllm_config.kv_transfer_config) + .get_config() + .get("ucm_sparse_config") + .get("ESA") + ) self.total_num_hidden_layers = ( vllm_config.model_config.hf_config.num_hidden_layers ) @@ -533,6 +536,7 @@ def get_or_create_layerwise_req_state(self, req_meta, layer_name): self._vllm_config, self.retrieval_workers[layer_id], self.layer_pools[layer_id], + self.esa_cfg, ) return self.req_states[req_meta.request_id][layer_id] diff --git a/ucm/sparse/factory.py b/ucm/sparse/factory.py index cb1b43ae..d5b49cf3 100644 --- a/ucm/sparse/factory.py +++ b/ucm/sparse/factory.py @@ -5,6 +5,7 @@ from ucm.logger import init_logger from ucm.sparse.base import UcmSparseBase, UcmSparseRole +from ucm.utils import Config logger = init_logger(__name__) @@ -30,9 +31,9 @@ def loader() -> type[UcmSparseBase]: def create_sparse_method( cls, config: "VllmConfig", role: UcmSparseRole ) -> UcmSparseBase: - ucm_cfg = config.kv_transfer_config.kv_connector_extra_config.get( - "ucm_sparse_config" - ) + ucm_config = Config(config.kv_transfer_config) + ucm_cfg = ucm_config.get_config().get("ucm_sparse_config") + sparse_method_name, _ = next(iter(ucm_cfg.items())) if sparse_method_name in 
cls._registry: sparse_method_cls = cls._registry[sparse_method_name]() diff --git a/ucm/sparse/kvcomp/kvcomp.py b/ucm/sparse/kvcomp/kvcomp.py index b2483c73..27fbe67d 100644 --- a/ucm/sparse/kvcomp/kvcomp.py +++ b/ucm/sparse/kvcomp/kvcomp.py @@ -1,7 +1,7 @@ import math from collections import defaultdict from dataclasses import dataclass -from typing import Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union import numpy as np import torch @@ -31,6 +31,7 @@ from ucm.sparse.kvstar.utils import get_bind_cpus_for_rank from ucm.sparse.state import get_ucm_sparse from ucm.store.ucmstore import Task, UcmKVStoreBase +from ucm.utils import Config logger = init_logger(__name__) @@ -51,6 +52,7 @@ def __init__( vllm_config: VllmConfig, retrieval_worker: Optional[HashRetrievalWorker] = None, repre_pool: Optional[ReprePool] = None, + esa_cfg: Optional[Dict[str, Any]] = None, ): super().__init__( layer_name, @@ -62,9 +64,7 @@ def __init__( repre_pool, ) - self.esa_cfg = vllm_config.kv_transfer_config.kv_connector_extra_config[ - "ucm_sparse_config" - ]["KvComp"] + self.esa_cfg = esa_cfg # `retrieval_worker` 类型是 HashRetrievalWorker self.retrieval_worker = retrieval_worker @@ -162,9 +162,12 @@ def __init__(self, vllm_config: VllmConfig, role: UcmSparseRole): self._sparse_metadata_prefill: ESASparseMetaData = ESASparseMetaData() self._sparse_metadata_decode: ESASparseMetaData = ESASparseMetaData() self._sparse_metadata: ESASparseMetaData = ESASparseMetaData() - self.esa_cfg = vllm_config.kv_transfer_config.kv_connector_extra_config[ - "ucm_sparse_config" - ]["KvComp"] + self.esa_cfg = ( + Config(vllm_config.kv_transfer_config) + .get_config() + .get("ucm_sparse_config") + .get("KvComp") + ) self.block_size = vllm_config.cache_config.block_size self.block_hashes: dict[int, dict[int, list[str]]] = {} @@ -271,6 +274,7 @@ def get_or_create_layerwise_req_state(self, req_meta, layer_name): self._vllm_config, self.retrieval_workers[layer_id], self.layer_pools[layer_id], + self.esa_cfg, ) return self.req_states[req_meta.request_id][layer_id] diff --git a/ucm/sparse/kvstar/multistep.py b/ucm/sparse/kvstar/multistep.py index 79b50e15..18ed4cb8 100644 --- a/ucm/sparse/kvstar/multistep.py +++ b/ucm/sparse/kvstar/multistep.py @@ -25,6 +25,7 @@ get_bind_cpus_for_rank, ) from ucm.store.ucmstore import Task, UcmKVStoreBase +from ucm.utils import Config """ -------------------------------------------------------------------------------------- @@ -683,9 +684,10 @@ def __init__(self, vllm_config: VllmConfig, role: UcmSparseRole): assert self._vllm_config.kv_transfer_config is not None self.kvstar_multistep_cfg = ( - vllm_config.kv_transfer_config.kv_connector_extra_config[ - "ucm_sparse_config" - ]["KVStarMultiStep"] + Config(vllm_config.kv_transfer_config) + .get_config() + .get("ucm_sparse_config") + .get("KVStarMultiStep") ) print(f"kvstar_multistep_cfg: {self.kvstar_multistep_cfg}") diff --git a/ucm/sparse/state.py b/ucm/sparse/state.py index a4e93c8d..a0f77a53 100644 --- a/ucm/sparse/state.py +++ b/ucm/sparse/state.py @@ -11,6 +11,7 @@ from ucm.logger import init_logger from ucm.sparse.base import UcmSparseBase, UcmSparseRole from ucm.sparse.factory import UcmSparseFactory +from ucm.utils import Config if TYPE_CHECKING: from vllm.config import VllmConfig @@ -37,15 +38,12 @@ def ensure_ucm_sparse_initialized( return # Check if UCM sparse is enabled - if ( - "ucm_sparse_config" - not in vllm_config.kv_transfer_config.kv_connector_extra_config - ): + ucm_config = 
Config(vllm_config.kv_transfer_config) + ucm_sparse_config = ucm_config.get_config().get("ucm_sparse_config") + if not ucm_sparse_config: return - sparse_method_name = vllm_config.kv_transfer_config.kv_connector_extra_config[ - "ucm_sparse_config" - ] + sparse_method_name = ucm_sparse_config if _UCM_SPARSE_AGENT is None: logger.info("Initializing UCM sparse agent with method: %s", sparse_method_name)
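
Reviewer note on the KV layout implied by the new compute_layer_offset: each stored block holds all layers back to back, with K followed by V per layer for non-MLA models and a single K slab per layer for MLA. The standalone sketch below reproduces the function added by this patch and prints the resulting offsets; the block_data_size value is an illustrative example (128 tokens * 8 KV heads * head_size 128 in fp16), not a value taken from any model config in the repo.

    from functools import cache


    @cache
    def compute_layer_offset(block_data_size: int, layer_id: int, is_v: bool, is_mla: bool) -> int:
        # Non-MLA layers store K then V per block, so each layer spans 2 * block_data_size bytes;
        # MLA stores a single latent cache, so each layer spans block_data_size bytes.
        layer_data_size = block_data_size if is_mla else block_data_size * 2
        k_offset = layer_data_size * layer_id
        if is_mla:
            return k_offset
        v_offset = k_offset + block_data_size
        return v_offset if is_v else k_offset


    # Illustrative sizing only: 128 tokens * 8 KV heads * head_size 128 * 2 bytes (fp16).
    block_data_size = 128 * 8 * 128 * 2  # 262144 bytes per K (or V) slab
    for layer_id in range(3):
        k = compute_layer_offset(block_data_size, layer_id, is_v=False, is_mla=False)
        v = compute_layer_offset(block_data_size, layer_id, is_v=True, is_mla=False)
        print(f"layer {layer_id}: K offset {k}, V offset {v}")
    # Layout inside one stored block: [L0-K][L0-V][L1-K][L1-V][L2-K][L2-V] ...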
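
Reviewer note on the new block-hash scheme (the ESA variant of set_block_hashes): hashes are chained from a seed derived from model metadata, local-window blocks and partial tail blocks are skipped, only blocks past the init window are emitted, and non-MLA tensor-parallel ranks other than 0 rehash each id once more. The sketch below mirrors that flow; the md5-based _hash_obj is only a stand-in for RequestHasher (whose internals are not part of this patch), and the model metadata passed to compute_parent_block_hash is illustrative.

    import hashlib
    import math
    import pickle


    def _hash_obj(obj) -> int:
        # Stand-in for RequestHasher: map any picklable object to an integer digest.
        data = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
        return int.from_bytes(hashlib.md5(data).digest(), byteorder="big")


    def compute_parent_block_hash(model_name, world_size, dtype, seed_rank=0) -> int:
        # Same seed derivation as this patch: model metadata plus a fixed salt.
        meta = f"{model_name}:{world_size}:{dtype}:{seed_rank}".encode("utf-8")
        return int.from_bytes(hashlib.md5(meta + b"UCM_HASH_SEED").digest(), byteorder="big")


    def chain_block_hashes(token_ids, block_size, init_window_sz, local_window_sz, rank, is_mla):
        # Illustrative model metadata; the real code reads these from vllm_config.
        parent = compute_parent_block_hash("example-model", 1, "bfloat16")
        num_total_blocks = math.ceil(len(token_ids) / block_size)
        hashes = []
        for start in range(0, len(token_ids), block_size):
            block_idx = start // block_size
            if block_idx >= num_total_blocks - local_window_sz:
                continue  # local-window blocks stay on device and get no hash
            block = tuple(token_ids[start:start + block_size])
            if len(block) < block_size:
                break  # partial tail block is not offloaded
            value = _hash_obj((parent, block))
            if block_idx >= init_window_sz:
                hashes.append(str(value))
            parent = value  # each block hash depends on all preceding full blocks
        if rank != 0 and not is_mla:
            # Non-MLA TP ranks hold different KV shards; the patch rehashes the ids with a
            # rank-specific RequestHasher, approximated here by folding the rank in.
            hashes = [str(_hash_obj((rank, h))) for h in hashes]
        return hashes


    # Example: 300 tokens, block_size 64 -> blocks 0..3 are full, block 4 is partial/local.
    print(chain_block_hashes(list(range(300)), 64, init_window_sz=1, local_window_sz=1, rank=0, is_mla=False))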