From 50ab55d3830f510d327c8f4f7ca1ee0ddb25e74e Mon Sep 17 00:00:00 2001 From: qyh Date: Wed, 10 Sep 2025 14:57:17 +0800 Subject: [PATCH 1/2] refactor ucconnector --- test/test_uc_connector.py | 217 ++++++-------- ucm/integration/vllm/uc_connector.py | 432 ++++++++++++--------------- ucm/store/ucm_nfs_store.py | 20 +- 3 files changed, 297 insertions(+), 372 deletions(-) diff --git a/test/test_uc_connector.py b/test/test_uc_connector.py index 1753783e..97e7222b 100644 --- a/test/test_uc_connector.py +++ b/test/test_uc_connector.py @@ -25,7 +25,7 @@ import random import secrets import unittest -from typing import List +from typing import List, Union from unittest.mock import MagicMock, Mock, patch import torch @@ -34,9 +34,9 @@ from vllm.v1.request import Request from ucm.integration.vllm.uc_connector import ( - LoadPara, + BlockOperation, ReqMeta, - SavePara, + RequestBlockInfo, UCConnectorV1Metadata, UnifiedCacheConnectorV1, ) @@ -100,10 +100,8 @@ def init_uc( ucconnector.rank = 1 ucconnector.is_mla = False ucconnector.connector = mock_connector - ucconnector.load_paras: dict[str, LoadPara] = {} - ucconnector.save_paras: dict[str, SavePara] = {} + ucconnector.request_block_infos: dict[str, RequestBlockInfo] = {} ucconnector.dump_tasks: dict[str, dict[str, List[Task]]] = {} - ucconnector.load_tasks: dict[str, tuple[Task, Task]] = {} ucconnector.total_tp_size = 2 ucconnector._connector_metadata = metadata ucconnector.layerwise_load_tasks: dict[ @@ -114,14 +112,47 @@ def init_uc( ucconnector._load_req_to_blocks: dict[str, set[int]] = {} return ucconnector - def test_get_num_new_matched_tokens_hit(self): + def test_get_num_new_matched_tokens_hit_all_on_storage(self): mock_connector = Mock(spec=UcmKVStoreBase) def mock_lookup(tokens: List[int]) -> List[bool]: return [True] * self.block_number + mock_connector.lookup.side_effect = mock_lookup + ucconnector = self.init_uc(mock_connector) + + random.seed(20250704) + request1 = make_request( + request_id=1, + prompt_token_ids=random.sample( + range(0, 10000), self.block_number * self.block_size + ), + mm_positions=None, + mm_hashes=None, + ) + + # all block dumped in ssd, external_tokens equals to full tokens num - self.block_size + all_tokens_len = len(request1.all_token_ids) + external_tokens, _ = ucconnector.get_num_new_matched_tokens(request1, 0) + self.assertEqual(external_tokens, all_tokens_len - self.block_size) + self.assertEqual( + ucconnector.request_block_infos[request1.request_id].block_operations, + [ + BlockOperation.LOAD, + BlockOperation.LOAD, + BlockOperation.LOAD, + BlockOperation.NONE, + ], + ) + + def test_get_num_new_matched_tokens_partial_hit(self): + mock_connector = Mock(spec=UcmKVStoreBase) + + def mock_lookup(tokens: List[int]) -> List[bool]: + return [True, False, True, False] + def mock_create(tokens: List[str]) -> List[int]: - return [1] * self.block_number + return [0, 1, 0] mock_connector.lookup.side_effect = mock_lookup mock_connector.create.side_effect = mock_create @@ -137,10 +168,60 @@ def mock_create(tokens: List[str]) -> List[int]: mm_hashes=None, ) - # all block dumped in ssd, external_tokens equals to full tokens num + # all block dumped in ssd, external_tokens equals to full tokens num - self.block_size all_tokens_len = len(request1.all_token_ids) external_tokens, _ = ucconnector.get_num_new_matched_tokens(request1, 0) - self.assertEqual(external_tokens, all_tokens_len - self.block_size) + self.assertEqual(external_tokens, self.block_size) + self.assertEqual( + 
ucconnector.request_block_infos[request1.request_id].block_operations, + [ + BlockOperation.LOAD, + BlockOperation.DUMP, + BlockOperation.NONE, + BlockOperation.DUMP, + ], + ) + + def test_get_num_new_matched_tokens_partial_hit_with_preftxcache(self): + mock_connector = Mock(spec=UcmKVStoreBase) + + def mock_lookup(tokens: List[int]) -> List[bool]: + return [False, True, False] + + def mock_create(tokens: List[str]) -> List[int]: + return [0, 1, 0] + + mock_connector.lookup.side_effect = mock_lookup + mock_connector.create.side_effect = mock_create + ucconnector = self.init_uc(mock_connector) + + random.seed(20250704) + request1 = make_request( + request_id=1, + prompt_token_ids=random.sample( + range(0, 10000), self.block_number * self.block_size + ), + mm_positions=None, + mm_hashes=None, + ) + + # no block dumped in ssd, external_tokens equals to 0 + external_tokens, _ = ucconnector.get_num_new_matched_tokens( + request1, self.block_size + ) + self.assertEqual(external_tokens, 0) + self.assertEqual( + ucconnector.request_block_infos[request1.request_id].start_position, 1 + ) + self.assertEqual( + ucconnector.request_block_infos[request1.request_id].block_operations, + [ + BlockOperation.NONE, + BlockOperation.DUMP, + BlockOperation.NONE, + BlockOperation.DUMP, + ], + ) def test_get_num_new_matched_tokens_no_hit(self): mock_connector = Mock(spec=UcmKVStoreBase) @@ -149,7 +230,7 @@ def mock_lookup(tokens: List[int]) -> List[bool]: return [False] * self.block_number def mock_create(tokens: List[str]) -> List[int]: - return [1] * self.block_number + return [0] * self.block_number mock_connector.lookup.side_effect = mock_lookup mock_connector.create.side_effect = mock_create @@ -192,15 +273,9 @@ def test_get_num_new_matched_tokens_invalid_para(self): def test_wait_for_save_not_layerwise_success(self): req_meta1 = MagicMock(spec=ReqMeta) req_meta1.request_id = "req1" - req_meta1.save_paras = SavePara( - num_blocks_need_save=self.block_number, - start_save_position=0, - num_blocks_to_save=self.block_number, - ) - req_meta1.save_paras.block_hashes = [ - secrets.token_hex(8) for _ in range(self.block_number) + req_meta1.dump_blocks = [ + (secrets.token_hex(8), i) for i in range(self.block_number) ] - req_meta1.vllm_block_ids = list(range(self.block_number)) metadata = UCConnectorV1Metadata() metadata.requests = [req_meta1] @@ -236,15 +311,10 @@ def test_wait_for_save_not_layerwise_invalid_para(self): def test_start_load_kv_not_layerwise_success(self): req_meta1 = MagicMock(spec=ReqMeta) req_meta1.request_id = "req1" - req_meta1.load_paras = LoadPara( - vllm_cached_tokens=1 * self.block_size, - storage_cached_tokens=self.block_number * self.block_size, - can_load=True, - ) - req_meta1.load_paras.block_hashes = [ - secrets.token_hex(8) for _ in range(self.block_number) + req_meta1.load_blocks = [ + (secrets.token_hex(8), i) for i in range(self.block_number) ] - req_meta1.vllm_block_ids = list(range(self.block_number)) + req_meta1.load_async = False metadata = UCConnectorV1Metadata() metadata.requests = [req_meta1] @@ -282,15 +352,9 @@ def test_start_load_kv_invalid_para(self): def test_start_load_kv_layerwise_success(self): req_meta1 = MagicMock(spec=ReqMeta) req_meta1.request_id = "req1" - req_meta1.load_paras = LoadPara( - vllm_cached_tokens=1 * self.block_size, - storage_cached_tokens=self.block_number * self.block_size, - can_load=True, - ) - req_meta1.load_paras.block_hashes = [ - secrets.token_hex(8) for _ in range(self.block_number) + req_meta1.load_blocks = [ + (secrets.token_hex(8), i) 
for i in range(self.block_number) ] - req_meta1.vllm_block_ids = list(range(self.block_number)) metadata = UCConnectorV1Metadata() metadata.requests = [req_meta1] @@ -309,89 +373,6 @@ def mock_load( ucconnector.start_load_kv(forward_context) assert mock_connector.load.call_count == 2 * self.num_layers - def test_generate_layerwise_load_tasks_success(self): - # init implement - mock_connector = Mock(spec=UcmKVStoreBase) - - def mock_load( - block_ids: List[str], offset: List[int], dst_tensor: List[torch.Tensor] - ) -> Task: - assert offset is not None and offset - assert dst_tensor is not None and dst_tensor - return Task() - - mock_connector.load.side_effect = mock_load - ucconnector = self.init_uc(mock_connector) - - # provide generate_layerwise_load_tasks params - fetch_block_ids = list(range(self.block_number * 2)) - fetch_block_hashes = [ - secrets.token_hex(8) for _ in range(self.block_number * 2) - ] - layer_to_tensor: dict[str, tuple[List[torch.Tensor], List[int]]] = {} - current_layer = 0 - for layer_name, kv_layer in self.kv_caches.items(): - tensors, offsets = ucconnector.get_tensor_and_offset_layerwise( - fetch_block_ids, kv_layer, layer_name - ) - layer_to_tensor[layer_name] = (tensors, offsets) - current_layer += 1 - # generate layerwise tasks - layerwise_load_task = ucconnector.generate_layerwise_load_tasks( - fetch_block_hashes, layer_to_tensor - ) - - for i in range(self.num_layers): - task = next(layerwise_load_task) - assert task is not None, f"layer {i} is None" - assert mock_connector.load.call_count == self.num_layers * 2 - - def test_generate_layerwise_load_tasks_invalid_params(self): - # init implement - mock_connector = Mock(spec=UcmKVStoreBase) - - def mock_load( - block_ids: List[str], offset: List[int], dst_tensor: List[torch.Tensor] - ) -> Task: - assert offset is not None and offset - assert dst_tensor is not None and dst_tensor - return Task() - - mock_connector.load.side_effect = mock_load - ucconnector = self.init_uc(mock_connector) - - # provide generate_layerwise_load_tasks params - fetch_block_ids = list(range(self.block_number * 2)) - fetch_block_hashes = [ - secrets.token_hex(8) for _ in range(self.block_number * 2) - ] - layer_to_tensor: dict[str, tuple[List[torch.Tensor], List[int]]] = {} - for layer_name, kv_layer in self.kv_caches.items(): - tensors, offsets = ucconnector.get_tensor_and_offset_layerwise( - fetch_block_ids, kv_layer, layer_name - ) - layer_to_tensor[layer_name] = (tensors, offsets) - # generate layerwise tasks - layerwise_load_task = ucconnector.generate_layerwise_load_tasks( - [], layer_to_tensor - ) - with self.assertRaises(AssertionError) as context: - next(layerwise_load_task) - self.assertEqual( - str(context.exception), - "The block hashes need to be fetched should not be None or empty.", - ) - - layerwise_load_task = ucconnector.generate_layerwise_load_tasks( - fetch_block_hashes, None - ) - with self.assertRaises(AssertionError) as context: - next(layerwise_load_task) - self.assertEqual( - str(context.exception), - "The layers of tensor need to be fetched should not be None or empty.", - ) - if __name__ == "__main__": unittest.main() diff --git a/ucm/integration/vllm/uc_connector.py b/ucm/integration/vllm/uc_connector.py index 21ce9dba..9c38de51 100644 --- a/ucm/integration/vllm/uc_connector.py +++ b/ucm/integration/vllm/uc_connector.py @@ -26,6 +26,7 @@ import hashlib import pickle from dataclasses import dataclass, field +from enum import Enum from typing import TYPE_CHECKING, Any, Generator, List, Optional, Union import 
torch @@ -52,70 +53,36 @@ logger = init_logger(__name__) -@dataclass -class LoadPara: - # Number of tokens cached in vLLM - vllm_cached_tokens: int = 0 - # Number of tokens cached in ssd - storage_cached_tokens: int = 0 - # Whether the scheduler allow us to load the blocks - can_load: bool = False - # block hashes - block_hashes: list[str] = field(default_factory=list) +class BlockOperation(Enum): + NONE = "none" + LOAD = "load" + DUMP = "dump" @dataclass -class SavePara: - # dump block ids - num_blocks_need_save: int = 0 - # start save position - start_save_position: int = 0 - # block hashes +class RequestBlockInfo: + # Hash values for all blocks block_hashes: list[str] = field(default_factory=list) - # num of blocks prepare to save - num_blocks_to_save: int = 0 - # num of blocks already saved - num_blocks_saved: int = 0 + # Operation type for each block + block_operations: list[BlockOperation] = field(default_factory=list) + # Next block position to process + start_position: int = 0 @dataclass class ReqMeta: - # Request ID, unique for each request request_id: str - # Request block id in vllm - vllm_block_ids: list[int] - # Load information - load_paras: Optional[LoadPara] = None - # Save information - save_paras: Optional[SavePara] = None - # Mark request which need load async + # list[(block_hash, vllm_block_id)] + load_blocks: list[tuple[str, int]] = field(default_factory=list) + # list[(block_hash, vllm_block_id)] + dump_blocks: list[tuple[str, int]] = field(default_factory=list) + # Whether use load_async load_async: bool = False @dataclass class UCConnectorV1Metadata(KVConnectorMetadata): - requests: list[ReqMeta] - - def __init__(self): - self.requests = [] - - def add_request( - self, - request_id: str, - vllm_block_ids: list[int], - load_paras: Optional[LoadPara] = None, - save_paras: Optional[SavePara] = None, - load_async: bool = False, - ) -> None: - self.requests.append( - ReqMeta( - request_id=request_id, - vllm_block_ids=vllm_block_ids, - load_paras=load_paras, - save_paras=save_paras, - load_async=load_async, - ) - ) + requests: list[ReqMeta] = field(default_factory=list) class UnifiedCacheConnectorV1(KVConnectorBase_V1): @@ -129,8 +96,7 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): self.rank = ( -1 if role == KVConnectorRole.SCHEDULER else get_world_group().local_rank ) - self.load_paras: dict[str, LoadPara] = {} - self.save_paras: dict[str, SavePara] = {} + self.request_block_infos: dict[str, RequestBlockInfo] = {} # dump tasks record request -> block -> list[task] self.dump_tasks: dict[str, dict[str, List[Task]]] = {} self.layerwise_load_tasks: dict[str, dict[str, tuple[Task, Task]]] = {} @@ -244,41 +210,6 @@ def get_tensor_and_offset_layerwise( v_offsets.append(v_data_offset) return k_tensors + v_tensors, k_offsets + v_offsets - def generate_layerwise_load_tasks( - self, - fetch_block_hashes, - layer_to_tensor: dict[str, tuple[List[torch.Tensor], List[int]]], - ) -> Generator[Optional[tuple[Task, Task]], None, None]: - - logger.debug(f"fetch_block_hashes is {fetch_block_hashes}") - assert ( - fetch_block_hashes is not None and fetch_block_hashes - ), "The block hashes need to be fetched should not be None or empty." - assert ( - layer_to_tensor is not None and layer_to_tensor - ), "The layers of tensor need to be fetched should not be None or empty." 
- - blocks_len = len(fetch_block_hashes) - - def load(tensor_list, offset_list) -> tuple[Task, Task]: - k_load_task = self.connector.load( - fetch_block_hashes, offset_list[:blocks_len], tensor_list[:blocks_len] - ) - v_load_task = None - if not self.is_mla: - v_load_task = self.connector.load( - fetch_block_hashes, - offset_list[blocks_len:], - tensor_list[blocks_len:], - ) - return k_load_task, v_load_task - - for layer_name, (tensor_list, offset_list) in layer_to_tensor.items(): - logger.debug(f"Start execute {layer_name} load task.") - yield load(tensor_list, offset_list) - - yield None - # ============================== # Worker-side methods # ============================== @@ -316,42 +247,26 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: self.layerwise_load_tasks.clear() self.current_layer = 0 for request in metadata.requests: - if request.load_paras is None or not request.load_paras.can_load: + if not request.load_blocks: continue - block_ids = request.vllm_block_ids - # Blocks id need to save should start after last vllm cached block - load_start_block_id = ( - request.load_paras.vllm_cached_tokens // self.block_size - ) - load_end_block_id = ( - request.load_paras.storage_cached_tokens // self.block_size - ) - fetch_block_ids = block_ids[load_start_block_id:load_end_block_id] - logger.debug( - f"fetch_block_ids = {fetch_block_ids},\n" - f"load_start_block_id = {load_start_block_id},\n" - f"load_end_block_id = {load_end_block_id},\n" - f"fetch_block_ids = {fetch_block_ids}" - ) - fetch_block_hashes = request.load_paras.block_hashes[ - load_start_block_id:load_end_block_id - ] - assert len(fetch_block_ids) == len(fetch_block_hashes) - blocks_len = len(fetch_block_ids) + + storage_block_ids = [block[0] for block in request.load_blocks] + vllm_block_ids = [block[1] for block in request.load_blocks] + blocks_len = len(storage_block_ids) self._load_req_to_blocks.setdefault(request.request_id, set()).update( - fetch_block_ids + vllm_block_ids ) for layer_name, kv_layer in self.kv_caches.items(): tensors, offsets = self.get_tensor_and_offset_layerwise( - fetch_block_ids, kv_layer, layer_name + vllm_block_ids, kv_layer, layer_name ) k_task_id = self.connector.load( - fetch_block_hashes, offsets[:blocks_len], tensors[:blocks_len] + storage_block_ids, offsets[:blocks_len], tensors[:blocks_len] ) v_task_id = None if not self.is_mla: v_task_id = self.connector.load( - fetch_block_hashes, + storage_block_ids, offsets[blocks_len:], tensors[blocks_len:], ) @@ -362,7 +277,7 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: v_task_id, ) - if request.load_async: + if request.load_async and request.request_id in self.layerwise_load_tasks: for _, (k_task, v_task) in self.layerwise_load_tasks[ request.request_id ].items(): @@ -373,7 +288,10 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: self._need_load_reqs[request.request_id].append(v_task) continue - if not self.use_layerwise: + if ( + not self.use_layerwise + and request.request_id in self.layerwise_load_tasks + ): for _, (k_task, v_task) in self.layerwise_load_tasks[ request.request_id ].items(): @@ -409,10 +327,16 @@ def wait_for_layer_load(self, layer_name: str) -> None: k_task, v_task = layer_to_task[layer_name] if self.connector.wait(k_task) != 0: self._load_failed_reqs.add(request_id) + logger.error( + f"Failed to load block for request {request_id} on layer {layer_name}" + ) continue if not self.is_mla: if self.connector.wait(v_task) != 0: 
self._load_failed_reqs.add(request_id) + logger.error( + f"Failed to load block for request {request_id} on layer {layer_name}" + ) continue logger.debug(f"Load tasks for {request_id} on layer {layer_name} finished.") @@ -446,31 +370,18 @@ def save_kv_layer( metadata = self._get_connector_metadata() assert isinstance(metadata, UCConnectorV1Metadata) - assert attn_metadata is not None, "The attn_metadata should not be None." for request in metadata.requests: - if request.save_paras is None or request.load_async: + if not request.dump_blocks or request.load_async: continue - save_param = request.save_paras - vllm_block_ids = request.vllm_block_ids[ - save_param.start_save_position : save_param.start_save_position - + save_param.num_blocks_to_save - ] - blocks_len = len(vllm_block_ids) + storage_block_ids = [block[0] for block in request.dump_blocks] + vllm_block_ids = [block[1] for block in request.dump_blocks] + blocks_len = len(storage_block_ids) tensors, offsets = self.get_tensor_and_offset_layerwise( vllm_block_ids, kv_layer, layer_name ) - storage_block_ids = save_param.block_hashes[ - save_param.num_blocks_saved : save_param.num_blocks_saved - + save_param.num_blocks_to_save - ] - logger.debug( - f"blocks length = {blocks_len},\n" - f"length of offsets = {len(offsets)},\n" - f"length of need save vllm_block_ids = {len(vllm_block_ids)},\n" - f"length of storage_block_ids = {len(storage_block_ids)},\n" - ) + if kv_layer[0].device.type == "npu": torch.npu.current_stream().synchronize() elif kv_layer[0].device.type == "cuda": @@ -521,35 +432,18 @@ def wait_for_tasks(): return success_dumped_blocks if success_dumped_blocks else None for request in metadata.requests: - if request.save_paras is None: + if not request.dump_blocks: continue - save_paras = request.save_paras - logger.debug( - f"num_blocks_saved = {save_paras.num_blocks_saved},\n" - f"num_blocks_to_save = {save_paras.num_blocks_to_save}\n" - ) - start_pos = save_paras.start_save_position - num_blocks = save_paras.num_blocks_to_save - num_blocks_saved = save_paras.num_blocks_saved - dump_block_ids = request.vllm_block_ids[start_pos : start_pos + num_blocks] - dump_vllm_block_hashes = save_paras.block_hashes[ - num_blocks_saved : num_blocks_saved + num_blocks - ] - - logger.debug( - f"dump block ids is {dump_block_ids},\n" - f"dump_vllm_block_hashes is {dump_vllm_block_hashes}\n" - ) - - assert len(dump_block_ids) == len(dump_vllm_block_hashes) - blocks_len = len(dump_block_ids) + storage_block_ids = [block[0] for block in request.dump_blocks] + vllm_block_ids = [block[1] for block in request.dump_blocks] + blocks_len = len(storage_block_ids) for layer_name, kv_layer in self.kv_caches.items(): tensors, offsets = self.get_tensor_and_offset_layerwise( - dump_block_ids, kv_layer, layer_name + vllm_block_ids, kv_layer, layer_name ) for block_id, offset, tensor in zip( - dump_vllm_block_hashes, offsets[:blocks_len], tensors[:blocks_len] + storage_block_ids, offsets[:blocks_len], tensors[:blocks_len] ): task = self.connector.dump([block_id], [offset], [tensor]) self.dump_tasks.setdefault(request.request_id, {}).setdefault( @@ -557,7 +451,7 @@ def wait_for_tasks(): ).append(task) if not self.is_mla: for block_id, offset, tensor in zip( - dump_vllm_block_hashes, + storage_block_ids, offsets[blocks_len:], tensors[blocks_len:], ): @@ -634,22 +528,35 @@ def md5(input) -> int: if not block_hashes: logger.debug("Maybe tokens too short to load.") return 0, False - hit_masks = self.connector.lookup(block_hashes) - num_external_computed_tokens = 
sum(hit_masks) * self.block_size - # When all the tokens are cached in ssd and can be divided by block size, - # we need to recompute the last token. This if condition will be removed - # once vLLM's scheduler provides a better solution in the future. - if num_external_computed_tokens == request.num_tokens: - num_external_computed_tokens -= self.block_size - self.load_paras[request.request_id] = LoadPara( - vllm_cached_tokens=num_computed_tokens, - storage_cached_tokens=num_external_computed_tokens, - block_hashes=block_hashes, - can_load=False, + + # Calculate start position (exclude blocks already in HBM) + start_position = num_computed_tokens // self.block_size + + block_operations = [BlockOperation.NONE] * len(block_hashes) + + remain_hashes = block_hashes[start_position:] + if not remain_hashes: + # All blocks are in HBM + return 0, False + + lookup_results = self.connector.lookup(remain_hashes) + + # Find the longest continuous match from the beginning + num_lookup_hits = 0 + for i, hit in enumerate(lookup_results): + if hit: + num_lookup_hits += 1 + block_operations[start_position + i] = BlockOperation.LOAD + else: + # TODO we will fix hole match later + break + logger.info( + f"\nnum_total_blocks: {len(block_hashes)}\n" + f"\nnum_lookup_hits on hbm: {start_position}\n" + f"\nnum_lookup_hits on storage except hbm: {num_lookup_hits}\n" ) - need_load_tokens = max(num_external_computed_tokens - num_computed_tokens, 0) - # Load async when Decode instance need to load. + # Load async when Decode instance need to load.kv_consumer" if hasattr(self, "kv_role") and self.kv_role == "kv_consumer": # Only trigger 1 asynchronous KV transfer per request. if ( @@ -659,31 +566,42 @@ def md5(input) -> int: return 0, False request.kv_transfer_params = request.kv_transfer_params or {} request.kv_transfer_params["load_async"] = False - if need_load_tokens > 0: + if num_lookup_hits > 0: + self.request_block_infos[request.request_id] = RequestBlockInfo( + block_hashes=block_hashes, + block_operations=block_operations, + start_position=start_position, + ) self._need_load_reqs[request.request_id] = [] - return need_load_tokens, True - - num_max_cached_tokens = max(num_external_computed_tokens, num_computed_tokens) - num_blocks_need_save = ( - len(request.all_token_ids) - num_max_cached_tokens - ) // self.block_size - if num_blocks_need_save > 0: - start_save_position = num_max_cached_tokens // self.block_size - need_allocate_block_hashes = block_hashes[start_save_position:] - rets = self.connector.create(need_allocate_block_hashes) - if rets and all(ret == 0 for ret in rets): - self.save_paras[request.request_id] = SavePara( - num_blocks_need_save=num_blocks_need_save, - start_save_position=start_save_position, - block_hashes=need_allocate_block_hashes, + return num_lookup_hits * self.block_size, True + + # Create blocks for the remaining (unmatched) blocks + if num_lookup_hits < len(remain_hashes): + remaining_hashes = remain_hashes[num_lookup_hits:] + create_results = self.connector.create(remaining_hashes) + logger.info(f"\ncreate_results on storage: {create_results}\n") + for j, ret in enumerate(create_results): + idx = num_lookup_hits + j + block_operations[start_position + idx] = ( + BlockOperation.DUMP if ret == 0 else BlockOperation.NONE ) - logger.debug( - f"num_blocks_need_save = {num_blocks_need_save},\n" - f"num_external_computed_tokens = {num_external_computed_tokens},\n" - f"num_computed_tokens = {num_computed_tokens}.\n" + + # When all the tokens are cached in ssd or hbm, + # we need to 
recompute the last token. This if condition will be removed + # once vLLM's scheduler provides a better solution in the future. + if (num_lookup_hits + start_position) * self.block_size == len( + request.all_token_ids + ): + num_lookup_hits -= 1 + block_operations[-1] = BlockOperation.NONE + + self.request_block_infos[request.request_id] = RequestBlockInfo( + block_hashes=block_hashes, + block_operations=block_operations, + start_position=start_position, ) - return need_load_tokens, False + return num_lookup_hits * self.block_size, False def update_state_after_alloc( self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int @@ -691,13 +609,6 @@ def update_state_after_alloc( """ Update KVConnector state after block allocation. """ - if request.request_id not in self.load_paras: - # No KV tokens from external KV cache, return - return - - if num_external_tokens > 0: - self.load_paras[request.request_id].can_load = True - if request.request_id in self._need_load_reqs: local_block_ids = ( blocks.get_unhashed_block_ids() if num_external_tokens > 0 else [] @@ -719,32 +630,38 @@ def build_connector_meta( meta = UCConnectorV1Metadata() for req_id, block_ids in self._need_load_reqs.items(): - meta.add_request( - req_id, - vllm_block_ids=block_ids, - load_paras=self.load_paras[req_id], - load_async=True, + block_info = self.request_block_infos.get(req_id) + if block_info: + load_blocks, dump_blocks = self._extract_blocks(block_ids, block_info) + meta.requests.append( + ReqMeta( + request_id=req_id, + load_blocks=load_blocks, + dump_blocks=dump_blocks, + load_async=True, + ) ) self._need_load_reqs.clear() for new_req in scheduler_output.scheduled_new_reqs: - # Load kv is only supported for new reqs - new_scheduled_blocks = ( - scheduler_output.num_scheduled_tokens[new_req.req_id] // self.block_size - ) - load_paras = self.load_paras.pop(new_req.req_id, None) - save_paras = self.save_paras.get(new_req.req_id, None) - if save_paras is not None: - save_paras.num_blocks_to_save = new_scheduled_blocks - meta.add_request( - new_req.req_id, - vllm_block_ids=new_req.block_ids[0], - load_paras=load_paras, - save_paras=save_paras, - ) - # clear all load_paras when build meta for new reqs done - self.load_paras.clear() + req_id = new_req.req_id + vllm_block_ids = new_req.block_ids[0] + block_info = self.request_block_infos.get(req_id) + if block_info: + load_blocks, dump_blocks = self._extract_blocks( + vllm_block_ids, block_info + ) + if load_blocks or dump_blocks: + meta.requests.append( + ReqMeta( + request_id=req_id, + load_blocks=load_blocks, + dump_blocks=dump_blocks, + ) + ) + + # Process cached requests using iterator cached_request_data = scheduler_output.scheduled_cached_reqs # Adapted for vllm 0.9.1, 0.9.2 and later versions @@ -770,25 +687,19 @@ def get_requests(): # When prompt tokens > max_num_batched_tokens, request of running requests may need to save for req_id, new_block_ids in get_requests(): - save_paras = self.save_paras.get(req_id) - if save_paras is None: - continue - - save_paras.num_blocks_saved += save_paras.num_blocks_to_save - - if save_paras.num_blocks_need_save > save_paras.num_blocks_saved: - logger.debug(f"Running request {req_id} has blocks to save") - save_paras.start_save_position = 0 - new_scheduled_blocks = ( - scheduler_output.num_scheduled_tokens[req_id] // self.block_size - ) - save_paras.num_blocks_to_save = new_scheduled_blocks - meta.add_request( - req_id, - vllm_block_ids=new_block_ids[0], - load_paras=None, - save_paras=save_paras, + block_info = 
self.request_block_infos.get(req_id) + if block_info: + load_blocks, dump_blocks = self._extract_blocks( + new_block_ids[0], block_info ) + if load_blocks or dump_blocks: + meta.requests.append( + ReqMeta( + request_id=req_id, + load_blocks=load_blocks, + dump_blocks=dump_blocks, + ) + ) return meta @@ -797,21 +708,54 @@ def request_finished( request: "Request", block_ids: list[int], ) -> tuple[bool, Optional[dict[str, Any]]]: - # clear save_paras for request - save_paras = self.save_paras.pop(request.request_id, None) + block_info = self.request_block_infos.pop(request.request_id, None) if hasattr(request, "succeed_dumped_blocks") and request.succeed_dumped_blocks: + logger.debug(f"commit {request.succeed_dumped_blocks} to True.") self.connector.commit(request.succeed_dumped_blocks, True) - if save_paras is not None: + if block_info is not None: cancel_blocks = [ - block - for block in save_paras.block_hashes - if hasattr(request, "succeed_dumped_blocks") - and block not in request.succeed_dumped_blocks + block_info.block_hashes[i] + for i, op in enumerate(block_info.block_operations) + if op == BlockOperation.DUMP + and hasattr(request, "succeed_dumped_blocks") + and block_info.block_hashes[i] not in request.succeed_dumped_blocks ] if cancel_blocks: + logger.warning(f"commit {cancel_blocks} to False.") self.connector.commit(cancel_blocks, False) return False, None + def _extract_blocks( + self, vllm_block_ids: list[int], block_info: RequestBlockInfo + ) -> tuple[list[tuple[str, int]], list[tuple[str, int]]]: + """ + Extract blocks that need load and dump, block_info.start_position + is the next block position to process, only return blocks that need + processing, NONE blocks are ignored. + """ + start_pos = block_info.start_position + + if start_pos >= len(block_info.block_operations): + return [], [] + + process_length = min( + len(block_info.block_operations) - start_pos, len(vllm_block_ids) + ) + ops = block_info.block_operations[start_pos : start_pos + process_length] + hashes = block_info.block_hashes[start_pos : start_pos + process_length] + vllm_ids = vllm_block_ids[:process_length] + + load_blocks = [] + dump_blocks = [] + for op, hash, vllm_id in zip(ops, hashes, vllm_ids): + if op == BlockOperation.LOAD: + load_blocks.append((hash, vllm_id)) + elif op == BlockOperation.DUMP: + dump_blocks.append((hash, vllm_id)) + + block_info.start_position += process_length + return load_blocks, dump_blocks + def get_block_ids_with_load_errors(self) -> set[int]: invalid_block_ids: set[int] = set() for req_id in self._load_failed_reqs: diff --git a/ucm/store/ucm_nfs_store.py b/ucm/store/ucm_nfs_store.py index 9fa51fa3..81595223 100644 --- a/ucm/store/ucm_nfs_store.py +++ b/ucm/store/ucm_nfs_store.py @@ -56,8 +56,10 @@ def __init__(self, config: Dict): param = ucmnfsstore.SetupParam(storage_backends, block_size, enableTransfer) if enableTransfer: param.transferDeviceId = device_id - param.transferStreamNumber = config["transferStreamNumber"] - param.transferIoSize = config["transferIoSize"] + if "transferStreamNumber" in config: + param.transferStreamNumber = config["transferStreamNumber"] + if "transferIoSize" in config: + param.transferIoSize = config["transferIoSize"] ret = ucmnfsstore.Setup(param) if ret != 0: msg = f"Failed to initialize ucmnfsstore, errcode: {ret}." 
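For reference while reviewing: a minimal, illustrative sketch (not part of this patch) of the block bookkeeping introduced in ucm/integration/vllm/uc_connector.py above. It only uses the dataclasses defined there; the hash values and vLLM block ids are made up for the example, and it assumes the patch is applied so the new names are importable.

from ucm.integration.vllm.uc_connector import BlockOperation, RequestBlockInfo

# Scheduler side: get_num_new_matched_tokens records, per block, whether it will be
# loaded from storage, dumped to storage after prefill, or left untouched.
info = RequestBlockInfo(
    block_hashes=["h0", "h1", "h2", "h3"],  # hypothetical hashes
    block_operations=[
        BlockOperation.LOAD,   # hit on storage -> load into HBM
        BlockOperation.DUMP,   # miss -> space created on storage, dump after prefill
        BlockOperation.NONE,   # skipped block (e.g. the recompute-last-token case)
        BlockOperation.DUMP,
    ],
    start_position=0,
)

# build_connector_meta then calls _extract_blocks(vllm_block_ids, info); with
# vllm_block_ids = [10, 11, 12, 13] it would return
#   load_blocks = [("h0", 10)]
#   dump_blocks = [("h1", 11), ("h3", 13)]
# and advance info.start_position to 4, so already-emitted blocks are not
# re-sent to the worker on later scheduler steps.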
@@ -77,9 +79,12 @@ def create(self, block_ids: List[str]) -> List[int]: """ rets = ucmnfsstore.AllocBatch(block_ids) if rets and all(ret == 0 for ret in rets): - logger.info("Succeed in allocating kv cache space.") + logger.debug("Succeed in allocating kv cache space.") else: failed_blocks = [block_ids[i] for i, ret in enumerate(rets) if ret != 0] + logger.warning( + f"Failed to allocate kv cache space for blocks: {failed_blocks}." + ) return rets def lookup(self, block_ids: List[str]) -> List[bool]: @@ -126,7 +131,7 @@ def load( block_ids, offset, dst_tensor_ptr, dst_tensor_size ) logger.debug( - f"Succeed in loading kv cache , task id: {task_id}, offset: {offset}." + f"Succeed in loading kv cache , task id: {task_id}, offset: {offset}, dst_tensor_size {dst_tensor_size}." ) return NfsTask(task_id=task_id) @@ -149,7 +154,7 @@ def dump( block_ids, offset, src_tensor_ptr, src_tensor_size ) logger.debug( - f"Succeed in dumping kv cache, task id: {task_id}, offset {offset}." + f"Succeed in dumping kv cache, task id: {task_id}, offset {offset}, src_tensor_size {src_tensor_size}." ) return NfsTask(task_id=task_id) @@ -169,8 +174,6 @@ def wait(self, task: Task) -> int: ret = ucmnfsstore.Wait(task.get_id()) if ret != 0: logger.error(f"Failed to wait for kv cache transfer task, errcode: {ret}.") - else: - logger.debug("Succeed in waiting for kv cache transfer task.") return ret def commit(self, block_ids: List[str], is_success: bool = True) -> None: @@ -181,10 +184,7 @@ def commit(self, block_ids: List[str], is_success: bool = True) -> None: block_ids (List[str]): vLLM block hash. is_success(bool): if False, we need release block """ - if not is_success: - logger.warning(f"commit {block_ids} to {is_success}") ucmnfsstore.CommitBatch(block_ids, is_success) - logger.debug("Succeed in committing kv cache.") def check(self, task: Task) -> int: """ From ac7b6a37967c7cb059639414c4309924e29660a1 Mon Sep 17 00:00:00 2001 From: qyh Date: Thu, 11 Sep 2025 19:51:21 +0800 Subject: [PATCH 2/2] fic comment --- ucm/integration/vllm/uc_connector.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/ucm/integration/vllm/uc_connector.py b/ucm/integration/vllm/uc_connector.py index 9c38de51..cce4706c 100644 --- a/ucm/integration/vllm/uc_connector.py +++ b/ucm/integration/vllm/uc_connector.py @@ -375,8 +375,13 @@ def save_kv_layer( if not request.dump_blocks or request.load_async: continue + # Extract storage block IDs and vLLM block IDs from dump_blocks, same for load_blocks + # dump_blocks format: [(block_hash, vllm_block_id), ...] + # Note: block_hash is the storage_block_id + # Example: [("hash_123", 5), ("hash_456", 8), ("hash_789", 12)] + # ["hash_123", "hash_456", "hash_789"] storage_block_ids = [block[0] for block in request.dump_blocks] - vllm_block_ids = [block[1] for block in request.dump_blocks] + vllm_block_ids = [block[1] for block in request.dump_blocks] # [5, 8, 12] blocks_len = len(storage_block_ids) tensors, offsets = self.get_tensor_and_offset_layerwise( vllm_block_ids, kv_layer, layer_name @@ -556,7 +561,7 @@ def md5(input) -> int: f"\nnum_lookup_hits on storage except hbm: {num_lookup_hits}\n" ) - # Load async when Decode instance need to load.kv_consumer" + # Load async when Decode instance need to load if hasattr(self, "kv_role") and self.kv_role == "kv_consumer": # Only trigger 1 asynchronous KV transfer per request. if (