Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions test/test_uc_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import random
import secrets
import unittest
from collections import defaultdict
from typing import List, Union
from unittest.mock import MagicMock, Mock, patch

Expand Down Expand Up @@ -106,12 +107,14 @@ def init_uc(
ucconnector.dump_tasks: dict[str, dict[str, List[Task]]] = {}
ucconnector.total_tp_size = self.total_tp_size
ucconnector._connector_metadata = metadata
ucconnector.layerwise_load_tasks: dict[
str, dict[str, tuple[Task, Task]]
] = {}
ucconnector.layerwise_load_tasks: dict[str, dict[str, Task]] = defaultdict(
dict
)
ucconnector._need_load_reqs: dict[str, Union[list[int], list[Task]]] = {}
ucconnector._load_failed_reqs: set[str] = set()
ucconnector._load_req_to_blocks: dict[str, set[int]] = {}
ucconnector.num_layers = 48
ucconnector.is_mla = False
return ucconnector

def test_get_num_new_matched_tokens_hit_all_on_storage(self):
Expand Down Expand Up @@ -508,6 +511,7 @@ def test_wait_for_save_not_layerwise_invalid_para(self):
ucconnector.block_size = self.block_size
ucconnector.use_layerwise = False
ucconnector._connector_metadata = Mock()
ucconnector.is_mla = False

with self.assertRaises(AssertionError):
ucconnector.wait_for_save()
Expand Down Expand Up @@ -542,6 +546,7 @@ def mock_wait(task: Task) -> int:
)
forward_context = Mock()
ucconnector.start_load_kv(forward_context)
assert mock_connector.load.call_count == 1

def test_start_load_kv_invalid_para(self):
with patch.object(UnifiedCacheConnectorV1, "__init__", return_value=None):
Expand All @@ -559,6 +564,7 @@ def test_start_load_kv_layerwise_success(self):
req_meta1.load_blocks = [
(secrets.token_hex(8), i) for i in range(self.block_number)
]
req_meta1.load_async = False

metadata = UCConnectorV1Metadata()
metadata.requests = [req_meta1]
Expand All @@ -575,7 +581,7 @@ def mock_load(
ucconnector = self.init_uc(mock_connector, metadata=metadata)
forward_context = Mock()
ucconnector.start_load_kv(forward_context)
assert mock_connector.load.call_count == 2 * self.num_layers
assert mock_connector.load.call_count == self.num_layers


if __name__ == "__main__":
Expand Down
121 changes: 48 additions & 73 deletions ucm/integration/vllm/uc_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@
#
import hashlib
import pickle
from collections import defaultdict
from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, Any, Generator, List, Optional, Union
from typing import TYPE_CHECKING, Any, List, Optional, Union

import torch
from vllm.config import VllmConfig
Expand Down Expand Up @@ -98,7 +99,7 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
self.request_block_infos: dict[str, RequestBlockInfo] = {}
# dump tasks record request -> block -> list[task]
self.dump_tasks: dict[str, dict[str, List[Task]]] = {}
self.layerwise_load_tasks: dict[str, dict[str, tuple[Task, Task]]] = {}
self.layerwise_load_tasks: dict[str, dict[str, Task]] = defaultdict(dict)
self.is_mla = self._vllm_config.model_config.is_deepseek_mla
self.num_layers = vllm_config.model_config.get_num_layers(
vllm_config.parallel_config
Expand Down Expand Up @@ -261,62 +262,43 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:

self.layerwise_load_tasks.clear()
self.current_layer = 0
need_wait_tasks = []
for request in metadata.requests:
if not request.load_blocks:
continue

storage_block_ids = [block[0] for block in request.load_blocks]
vllm_block_ids = [block[1] for block in request.load_blocks]
blocks_len = len(storage_block_ids)
self._load_req_to_blocks.setdefault(request.request_id, set()).update(
vllm_block_ids
)
is_load_async = request.load_async
total_offsets = []
total_tensors = []
storage_block_ids = storage_block_ids * (1 if self.is_mla else 2)
for layer_name, kv_layer in self.kv_caches.items():
tensors, offsets = self.get_tensor_and_offset_layerwise(
vllm_block_ids, kv_layer, layer_name
)
k_task_id = self.connector.load(
storage_block_ids, offsets[:blocks_len], tensors[:blocks_len]
)
v_task_id = None
if not self.is_mla:
v_task_id = self.connector.load(
storage_block_ids,
offsets[blocks_len:],
tensors[blocks_len:],
)
if request.request_id not in self.layerwise_load_tasks:
self.layerwise_load_tasks[request.request_id] = {}
self.layerwise_load_tasks[request.request_id][layer_name] = (
k_task_id,
v_task_id,
if self.use_layerwise and not is_load_async:
task_id = self.connector.load(storage_block_ids, offsets, tensors)
self.layerwise_load_tasks[request.request_id][layer_name] = task_id
continue
else:
total_offsets.extend(offsets)
total_tensors.extend(tensors)
if total_offsets and total_tensors:
storage_block_ids = storage_block_ids * self.num_layers
task_id = self.connector.load(
storage_block_ids, total_offsets, total_tensors
)

if request.load_async and request.request_id in self.layerwise_load_tasks:
for _, (k_task, v_task) in self.layerwise_load_tasks[
request.request_id
].items():
if request.request_id not in self._need_load_reqs:
self._need_load_reqs[request.request_id] = []
self._need_load_reqs[request.request_id].append(k_task)
if not self.is_mla:
self._need_load_reqs[request.request_id].append(v_task)
self.layerwise_load_tasks.pop(request.request_id)
continue

if (
not self.use_layerwise
and request.request_id in self.layerwise_load_tasks
):
for _, (k_task, v_task) in self.layerwise_load_tasks[
request.request_id
].items():
if self.connector.wait(k_task) != 0:
self._load_failed_reqs.add(request.request_id)
break
if v_task and self.connector.wait(v_task) != 0:
self._load_failed_reqs.add(request.request_id)
break
if is_load_async:
self._need_load_reqs[request.request_id] = task_id
else:
need_wait_tasks.append(task_id)
for task_id in need_wait_tasks:
if self.connector.wait(task_id) != 0:
self._load_failed_reqs.add(request.request_id)

def wait_for_layer_load(self, layer_name: str) -> None:
"""
Expand All @@ -340,20 +322,13 @@ def wait_for_layer_load(self, layer_name: str) -> None:
for request_id, layer_to_task in self.layerwise_load_tasks.items():
if request_id in self._load_failed_reqs:
continue
k_task, v_task = layer_to_task[layer_name]
if self.connector.wait(k_task) != 0:
task = layer_to_task[layer_name]
if self.connector.wait(task) != 0:
self._load_failed_reqs.add(request_id)
logger.error(
f"Failed to load block for request {request_id} on layer {layer_name}"
)
continue
if not self.is_mla:
if self.connector.wait(v_task) != 0:
self._load_failed_reqs.add(request_id)
logger.error(
f"Failed to load block for request {request_id} on layer {layer_name}"
)
continue
logger.debug(f"Load tasks for {request_id} on layer {layer_name} finished.")

def save_kv_layer(
Expand Down Expand Up @@ -437,6 +412,8 @@ def wait_for_save(self) -> Optional[dict[str, list[str]]]:
"""
if hasattr(self, "kv_role") and self.kv_role == "kv_consumer":
return
if self.is_mla and self.rank != 0:
return
# request id -> succeed dumped blocks
success_dumped_blocks: dict[str, list[str]] = {}

Expand All @@ -455,36 +432,34 @@ def wait_for_tasks():
self.dump_tasks.clear()
return success_dumped_blocks if success_dumped_blocks else None

req_to_dump_blocks: dict[str, list[str]] = {}
need_dump_tasks: dict[str, Task] = {}
for request in metadata.requests:
if not request.dump_blocks:
continue

storage_block_ids = [block[0] for block in request.dump_blocks]
vllm_block_ids = [block[1] for block in request.dump_blocks]
blocks_len = len(storage_block_ids)
req_to_dump_blocks[request.request_id] = storage_block_ids
total_offsets = []
total_tensors = []
total_block_ids = (
storage_block_ids * (1 if self.is_mla else 2) * self.num_layers
)
for layer_name, kv_layer in self.kv_caches.items():
tensors, offsets = self.get_tensor_and_offset_layerwise(
vllm_block_ids, kv_layer, layer_name
)
for block_id, offset, tensor in zip(
storage_block_ids, offsets[:blocks_len], tensors[:blocks_len]
):
task = self.connector.dump([block_id], [offset], [tensor])
self.dump_tasks.setdefault(request.request_id, {}).setdefault(
block_id, []
).append(task)
if not self.is_mla:
for block_id, offset, tensor in zip(
storage_block_ids,
offsets[blocks_len:],
tensors[blocks_len:],
):
task = self.connector.dump([block_id], [offset], [tensor])
self.dump_tasks.setdefault(request.request_id, {}).setdefault(
block_id, []
).append(task)
wait_for_tasks()
self.dump_tasks.clear()
total_offsets.extend(offsets)
total_tensors.extend(tensors)
task_id = self.connector.dump(total_block_ids, total_offsets, total_tensors)
need_dump_tasks[request.request_id] = task_id

for req_id, task_id in need_dump_tasks.items():
if self.connector.wait(task_id) != 0:
logger.error(f"Failed to dump blocks for req {request.request_id}")
else:
success_dumped_blocks[req_id] = req_to_dump_blocks[req_id]
return success_dumped_blocks if success_dumped_blocks else None

def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
Expand Down