diff --git a/test/test_uc_connector.py b/test/test_uc_connector.py index d4a0caeb..0c2261d8 100644 --- a/test/test_uc_connector.py +++ b/test/test_uc_connector.py @@ -25,6 +25,7 @@ import random import secrets import unittest +from collections import defaultdict from typing import List, Union from unittest.mock import MagicMock, Mock, patch @@ -106,12 +107,14 @@ def init_uc( ucconnector.dump_tasks: dict[str, dict[str, List[Task]]] = {} ucconnector.total_tp_size = self.total_tp_size ucconnector._connector_metadata = metadata - ucconnector.layerwise_load_tasks: dict[ - str, dict[str, tuple[Task, Task]] - ] = {} + ucconnector.layerwise_load_tasks: dict[str, dict[str, Task]] = defaultdict( + dict + ) ucconnector._need_load_reqs: dict[str, Union[list[int], list[Task]]] = {} ucconnector._load_failed_reqs: set[str] = set() ucconnector._load_req_to_blocks: dict[str, set[int]] = {} + ucconnector.num_layers = 48 + ucconnector.is_mla = False return ucconnector def test_get_num_new_matched_tokens_hit_all_on_storage(self): @@ -508,6 +511,7 @@ def test_wait_for_save_not_layerwise_invalid_para(self): ucconnector.block_size = self.block_size ucconnector.use_layerwise = False ucconnector._connector_metadata = Mock() + ucconnector.is_mla = False with self.assertRaises(AssertionError): ucconnector.wait_for_save() @@ -542,6 +546,7 @@ def mock_wait(task: Task) -> int: ) forward_context = Mock() ucconnector.start_load_kv(forward_context) + assert mock_connector.load.call_count == 1 def test_start_load_kv_invalid_para(self): with patch.object(UnifiedCacheConnectorV1, "__init__", return_value=None): @@ -559,6 +564,7 @@ def test_start_load_kv_layerwise_success(self): req_meta1.load_blocks = [ (secrets.token_hex(8), i) for i in range(self.block_number) ] + req_meta1.load_async = False metadata = UCConnectorV1Metadata() metadata.requests = [req_meta1] @@ -575,7 +581,7 @@ def mock_load( ucconnector = self.init_uc(mock_connector, metadata=metadata) forward_context = Mock() 
ucconnector.start_load_kv(forward_context) - assert mock_connector.load.call_count == 2 * self.num_layers + assert mock_connector.load.call_count == self.num_layers if __name__ == "__main__": diff --git a/ucm/integration/vllm/uc_connector.py b/ucm/integration/vllm/uc_connector.py index dac3d8a9..87c33ec8 100644 --- a/ucm/integration/vllm/uc_connector.py +++ b/ucm/integration/vllm/uc_connector.py @@ -25,9 +25,10 @@ # import hashlib import pickle +from collections import defaultdict from dataclasses import dataclass, field from enum import Enum -from typing import TYPE_CHECKING, Any, Generator, List, Optional, Union +from typing import TYPE_CHECKING, Any, List, Optional, Union import torch from vllm.config import VllmConfig @@ -98,7 +99,7 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): self.request_block_infos: dict[str, RequestBlockInfo] = {} # dump tasks record request -> block -> list[task] self.dump_tasks: dict[str, dict[str, List[Task]]] = {} - self.layerwise_load_tasks: dict[str, dict[str, tuple[Task, Task]]] = {} + self.layerwise_load_tasks: dict[str, dict[str, Task]] = defaultdict(dict) self.is_mla = self._vllm_config.model_config.is_deepseek_mla self.num_layers = vllm_config.model_config.get_num_layers( vllm_config.parallel_config @@ -261,62 +262,43 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: self.layerwise_load_tasks.clear() self.current_layer = 0 + need_wait_tasks = [] for request in metadata.requests: if not request.load_blocks: continue storage_block_ids = [block[0] for block in request.load_blocks] vllm_block_ids = [block[1] for block in request.load_blocks] - blocks_len = len(storage_block_ids) self._load_req_to_blocks.setdefault(request.request_id, set()).update( vllm_block_ids ) + is_load_async = request.load_async + total_offsets = [] + total_tensors = [] + storage_block_ids = storage_block_ids * (1 if self.is_mla else 2) for layer_name, kv_layer in self.kv_caches.items(): tensors, 
offsets = self.get_tensor_and_offset_layerwise( vllm_block_ids, kv_layer, layer_name ) - k_task_id = self.connector.load( - storage_block_ids, offsets[:blocks_len], tensors[:blocks_len] - ) - v_task_id = None - if not self.is_mla: - v_task_id = self.connector.load( - storage_block_ids, - offsets[blocks_len:], - tensors[blocks_len:], - ) - if request.request_id not in self.layerwise_load_tasks: - self.layerwise_load_tasks[request.request_id] = {} - self.layerwise_load_tasks[request.request_id][layer_name] = ( - k_task_id, - v_task_id, + if self.use_layerwise and not is_load_async: + task_id = self.connector.load(storage_block_ids, offsets, tensors) + self.layerwise_load_tasks[request.request_id][layer_name] = task_id + continue + else: + total_offsets.extend(offsets) + total_tensors.extend(tensors) + if total_offsets and total_tensors: + storage_block_ids = storage_block_ids * self.num_layers + task_id = self.connector.load( + storage_block_ids, total_offsets, total_tensors ) - - if request.load_async and request.request_id in self.layerwise_load_tasks: - for _, (k_task, v_task) in self.layerwise_load_tasks[ - request.request_id - ].items(): - if request.request_id not in self._need_load_reqs: - self._need_load_reqs[request.request_id] = [] - self._need_load_reqs[request.request_id].append(k_task) - if not self.is_mla: - self._need_load_reqs[request.request_id].append(v_task) - self.layerwise_load_tasks.pop(request.request_id) - continue - - if ( - not self.use_layerwise - and request.request_id in self.layerwise_load_tasks - ): - for _, (k_task, v_task) in self.layerwise_load_tasks[ - request.request_id - ].items(): - if self.connector.wait(k_task) != 0: - self._load_failed_reqs.add(request.request_id) - break - if v_task and self.connector.wait(v_task) != 0: - self._load_failed_reqs.add(request.request_id) - break + if is_load_async: + self._need_load_reqs[request.request_id] = task_id + else: + need_wait_tasks.append((request.request_id, task_id)) + for req_id, task_id in need_wait_tasks: 
if self.connector.wait(task_id) != 0: + self._load_failed_reqs.add(req_id) def wait_for_layer_load(self, layer_name: str) -> None: """ @@ -340,20 +322,13 @@ def wait_for_layer_load(self, layer_name: str) -> None: for request_id, layer_to_task in self.layerwise_load_tasks.items(): if request_id in self._load_failed_reqs: continue - k_task, v_task = layer_to_task[layer_name] - if self.connector.wait(k_task) != 0: + task = layer_to_task[layer_name] + if self.connector.wait(task) != 0: self._load_failed_reqs.add(request_id) logger.error( f"Failed to load block for request {request_id} on layer {layer_name}" ) continue - if not self.is_mla: - if self.connector.wait(v_task) != 0: - self._load_failed_reqs.add(request_id) - logger.error( - f"Failed to load block for request {request_id} on layer {layer_name}" - ) - continue logger.debug(f"Load tasks for {request_id} on layer {layer_name} finished.") def save_kv_layer( @@ -437,6 +412,8 @@ def wait_for_save(self) -> Optional[dict[str, list[str]]]: """ if hasattr(self, "kv_role") and self.kv_role == "kv_consumer": return + if self.is_mla and self.rank != 0: + return # request id -> succeed dumped blocks success_dumped_blocks: dict[str, list[str]] = {} @@ -455,36 +432,34 @@ def wait_for_tasks(): self.dump_tasks.clear() return success_dumped_blocks if success_dumped_blocks else None + req_to_dump_blocks: dict[str, list[str]] = {} + need_dump_tasks: dict[str, Task] = {} for request in metadata.requests: if not request.dump_blocks: continue storage_block_ids = [block[0] for block in request.dump_blocks] vllm_block_ids = [block[1] for block in request.dump_blocks] - blocks_len = len(storage_block_ids) + req_to_dump_blocks[request.request_id] = storage_block_ids + total_offsets = [] + total_tensors = [] + total_block_ids = ( + storage_block_ids * (1 if self.is_mla else 2) * self.num_layers + ) for layer_name, kv_layer in self.kv_caches.items(): tensors, offsets = self.get_tensor_and_offset_layerwise( vllm_block_ids, 
kv_layer, layer_name ) - for block_id, offset, tensor in zip( - storage_block_ids, offsets[:blocks_len], tensors[:blocks_len] - ): - task = self.connector.dump([block_id], [offset], [tensor]) - self.dump_tasks.setdefault(request.request_id, {}).setdefault( - block_id, [] - ).append(task) - if not self.is_mla: - for block_id, offset, tensor in zip( - storage_block_ids, - offsets[blocks_len:], - tensors[blocks_len:], - ): - task = self.connector.dump([block_id], [offset], [tensor]) - self.dump_tasks.setdefault(request.request_id, {}).setdefault( - block_id, [] - ).append(task) - wait_for_tasks() - self.dump_tasks.clear() + total_offsets.extend(offsets) + total_tensors.extend(tensors) + task_id = self.connector.dump(total_block_ids, total_offsets, total_tensors) + need_dump_tasks[request.request_id] = task_id + + for req_id, task_id in need_dump_tasks.items(): + if self.connector.wait(task_id) != 0: + logger.error(f"Failed to dump blocks for req {req_id}") + else: + success_dumped_blocks[req_id] = req_to_dump_blocks[req_id] return success_dumped_blocks if success_dumped_blocks else None def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]: