Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
217 changes: 99 additions & 118 deletions test/test_uc_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import random
import secrets
import unittest
from typing import List
from typing import List, Union
from unittest.mock import MagicMock, Mock, patch

import torch
Expand All @@ -34,9 +34,9 @@
from vllm.v1.request import Request

from ucm.integration.vllm.uc_connector import (
LoadPara,
BlockOperation,
ReqMeta,
SavePara,
RequestBlockInfo,
UCConnectorV1Metadata,
UnifiedCacheConnectorV1,
)
Expand Down Expand Up @@ -100,10 +100,8 @@ def init_uc(
ucconnector.rank = 1
ucconnector.is_mla = False
ucconnector.connector = mock_connector
ucconnector.load_paras: dict[str, LoadPara] = {}
ucconnector.save_paras: dict[str, SavePara] = {}
ucconnector.request_block_infos: dict[str, RequestBlockInfo] = {}
ucconnector.dump_tasks: dict[str, dict[str, List[Task]]] = {}
ucconnector.load_tasks: dict[str, tuple[Task, Task]] = {}
ucconnector.total_tp_size = 2
ucconnector._connector_metadata = metadata
ucconnector.layerwise_load_tasks: dict[
Expand All @@ -114,14 +112,47 @@ def init_uc(
ucconnector._load_req_to_blocks: dict[str, set[int]] = {}
return ucconnector

def test_get_num_new_matched_tokens_hit(self):
def test_get_num_new_matched_tokens_hit_all_on_storage(self):
    """Full storage hit: every block hash is found by the store lookup.

    The connector should report all tokens except the final block as
    externally matched, and mark each fully hit block LOAD while the
    last block stays NONE.
    """
    store = Mock(spec=UcmKVStoreBase)
    # Pretend the external store holds every queried block.
    store.lookup.side_effect = lambda tokens: [True] * self.block_number

    connector = self.init_uc(store)

    random.seed(20250704)
    total_tokens = self.block_number * self.block_size
    request = make_request(
        request_id=1,
        prompt_token_ids=random.sample(range(0, 10000), total_tokens),
        mm_positions=None,
        mm_hashes=None,
    )

    # All blocks dumped in ssd: external tokens == full token count minus
    # one block (the last block is recomputed rather than loaded).
    matched, _ = connector.get_num_new_matched_tokens(request, 0)
    self.assertEqual(matched, len(request.all_token_ids) - self.block_size)

    expected_ops = [BlockOperation.LOAD] * 3 + [BlockOperation.NONE]
    self.assertEqual(
        connector.request_block_infos[request.request_id].block_operations,
        expected_ops,
    )

def test_get_num_new_matched_tokens_partial_hit(self):
mock_connector = Mock(spec=UcmKVStoreBase)

def mock_lookup(tokens: List[int]) -> List[bool]:
return [True, False, True, False]

def mock_create(tokens: List[str]) -> List[int]:
return [1] * self.block_number
return [0, 1, 0]

mock_connector.lookup.side_effect = mock_lookup
mock_connector.create.side_effect = mock_create
Expand All @@ -137,10 +168,60 @@ def mock_create(tokens: List[str]) -> List[int]:
mm_hashes=None,
)

# all block dumped in ssd, external_tokens equals to full tokens num
# all block dumped in ssd, external_tokens equals to full tokens num - self.block_size
all_tokens_len = len(request1.all_token_ids)
external_tokens, _ = ucconnector.get_num_new_matched_tokens(request1, 0)
self.assertEqual(external_tokens, all_tokens_len - self.block_size)
self.assertEqual(external_tokens, self.block_size)
self.assertEqual(
ucconnector.request_block_infos[request1.request_id].block_operations,
[
BlockOperation.LOAD,
BlockOperation.DUMP,
BlockOperation.NONE,
BlockOperation.DUMP,
],
)

def test_get_num_new_matched_tokens_partial_hit_with_preftxcache(self):
    """Partial hit behind vLLM's prefix cache.

    One block is already covered by the prefix cache (num_computed_tokens
    == block_size), and the storage lookup for the remaining blocks misses
    around a single hit, so no additional external tokens are reported;
    missed blocks are scheduled for DUMP and start_position skips the
    prefix-cached block.
    """
    store = Mock(spec=UcmKVStoreBase)
    store.lookup.side_effect = lambda tokens: [False, True, False]
    store.create.side_effect = lambda tokens: [0, 1, 0]

    connector = self.init_uc(store)

    random.seed(20250704)
    request = make_request(
        request_id=1,
        prompt_token_ids=random.sample(
            range(0, 10000), self.block_number * self.block_size
        ),
        mm_positions=None,
        mm_hashes=None,
    )

    # No loadable block beyond the prefix cache -> zero external tokens.
    matched, _ = connector.get_num_new_matched_tokens(
        request, self.block_size
    )
    self.assertEqual(matched, 0)

    info = connector.request_block_infos[request.request_id]
    self.assertEqual(info.start_position, 1)
    self.assertEqual(
        info.block_operations,
        [
            BlockOperation.NONE,
            BlockOperation.DUMP,
            BlockOperation.NONE,
            BlockOperation.DUMP,
        ],
    )

def test_get_num_new_matched_tokens_no_hit(self):
mock_connector = Mock(spec=UcmKVStoreBase)
Expand All @@ -149,7 +230,7 @@ def mock_lookup(tokens: List[int]) -> List[bool]:
return [False] * self.block_number

def mock_create(tokens: List[str]) -> List[int]:
return [1] * self.block_number
return [0] * self.block_number

mock_connector.lookup.side_effect = mock_lookup
mock_connector.create.side_effect = mock_create
Expand Down Expand Up @@ -192,15 +273,9 @@ def test_get_num_new_matched_tokens_invalid_para(self):
def test_wait_for_save_not_layerwise_success(self):
req_meta1 = MagicMock(spec=ReqMeta)
req_meta1.request_id = "req1"
req_meta1.save_paras = SavePara(
num_blocks_need_save=self.block_number,
start_save_position=0,
num_blocks_to_save=self.block_number,
)
req_meta1.save_paras.block_hashes = [
secrets.token_hex(8) for _ in range(self.block_number)
req_meta1.dump_blocks = [
(secrets.token_hex(8), i) for i in range(self.block_number)
]
req_meta1.vllm_block_ids = list(range(self.block_number))

metadata = UCConnectorV1Metadata()
metadata.requests = [req_meta1]
Expand Down Expand Up @@ -236,15 +311,10 @@ def test_wait_for_save_not_layerwise_invalid_para(self):
def test_start_load_kv_not_layerwise_success(self):
req_meta1 = MagicMock(spec=ReqMeta)
req_meta1.request_id = "req1"
req_meta1.load_paras = LoadPara(
vllm_cached_tokens=1 * self.block_size,
storage_cached_tokens=self.block_number * self.block_size,
can_load=True,
)
req_meta1.load_paras.block_hashes = [
secrets.token_hex(8) for _ in range(self.block_number)
req_meta1.load_blocks = [
(secrets.token_hex(8), i) for i in range(self.block_number)
]
req_meta1.vllm_block_ids = list(range(self.block_number))
req_meta1.load_async = False

metadata = UCConnectorV1Metadata()
metadata.requests = [req_meta1]
Expand Down Expand Up @@ -282,15 +352,9 @@ def test_start_load_kv_invalid_para(self):
def test_start_load_kv_layerwise_success(self):
req_meta1 = MagicMock(spec=ReqMeta)
req_meta1.request_id = "req1"
req_meta1.load_paras = LoadPara(
vllm_cached_tokens=1 * self.block_size,
storage_cached_tokens=self.block_number * self.block_size,
can_load=True,
)
req_meta1.load_paras.block_hashes = [
secrets.token_hex(8) for _ in range(self.block_number)
req_meta1.load_blocks = [
(secrets.token_hex(8), i) for i in range(self.block_number)
]
req_meta1.vllm_block_ids = list(range(self.block_number))

metadata = UCConnectorV1Metadata()
metadata.requests = [req_meta1]
Expand All @@ -309,89 +373,6 @@ def mock_load(
ucconnector.start_load_kv(forward_context)
assert mock_connector.load.call_count == 2 * self.num_layers

def test_generate_layerwise_load_tasks_success(self):
    """generate_layerwise_load_tasks yields one task per layer.

    Each advance of the generator must produce a non-None task, and the
    backing connector must receive two load calls per layer (one for the
    key tensors, one for the value tensors).
    """
    mock_connector = Mock(spec=UcmKVStoreBase)

    def mock_load(
        block_ids: List[str], offset: List[int], dst_tensor: List[torch.Tensor]
    ) -> Task:
        # Every load must receive non-empty offsets and destination tensors.
        assert offset is not None and offset
        assert dst_tensor is not None and dst_tensor
        return Task()

    mock_connector.load.side_effect = mock_load
    ucconnector = self.init_uc(mock_connector)

    # Build per-layer (tensors, offsets) pairs for 2 * block_number blocks.
    # NOTE: the original version kept a `current_layer` counter here that
    # was incremented but never read; it has been removed as dead code.
    fetch_block_ids = list(range(self.block_number * 2))
    fetch_block_hashes = [
        secrets.token_hex(8) for _ in range(self.block_number * 2)
    ]
    layer_to_tensor: dict[str, tuple[List[torch.Tensor], List[int]]] = {}
    for layer_name, kv_layer in self.kv_caches.items():
        tensors, offsets = ucconnector.get_tensor_and_offset_layerwise(
            fetch_block_ids, kv_layer, layer_name
        )
        layer_to_tensor[layer_name] = (tensors, offsets)

    # Generate layerwise tasks and drain one task per layer.
    layerwise_load_task = ucconnector.generate_layerwise_load_tasks(
        fetch_block_hashes, layer_to_tensor
    )
    for i in range(self.num_layers):
        task = next(layerwise_load_task)
        assert task is not None, f"layer {i} is None"
    # Two load calls per layer: K tensors + V tensors.
    assert mock_connector.load.call_count == self.num_layers * 2

def test_generate_layerwise_load_tasks_invalid_params(self):
    """Invalid inputs must fail fast on the first generator advance.

    An empty block-hash list and a missing layer map each raise an
    AssertionError with a specific message when the lazily evaluated
    generator is first driven by next().
    """
    store = Mock(spec=UcmKVStoreBase)

    def fake_load(
        block_ids: List[str], offset: List[int], dst_tensor: List[torch.Tensor]
    ) -> Task:
        assert offset is not None and offset
        assert dst_tensor is not None and dst_tensor
        return Task()

    store.load.side_effect = fake_load
    connector = self.init_uc(store)

    # Prepare otherwise-valid arguments for the generator.
    block_ids = list(range(self.block_number * 2))
    block_hashes = [secrets.token_hex(8) for _ in range(self.block_number * 2)]
    layer_map: dict[str, tuple[List[torch.Tensor], List[int]]] = {}
    for name, kv_layer in self.kv_caches.items():
        layer_map[name] = connector.get_tensor_and_offset_layerwise(
            block_ids, kv_layer, name
        )

    # Empty hash list -> assertion on first next().
    generator = connector.generate_layerwise_load_tasks([], layer_map)
    with self.assertRaises(AssertionError) as ctx:
        next(generator)
    self.assertEqual(
        str(ctx.exception),
        "The block hashes need to be fetched should not be None or empty.",
    )

    # Missing layer map -> assertion on first next().
    generator = connector.generate_layerwise_load_tasks(block_hashes, None)
    with self.assertRaises(AssertionError) as ctx:
        next(generator)
    self.assertEqual(
        str(ctx.exception),
        "The layers of tensor need to be fetched should not be None or empty.",
    )


# Allow running this test module directly: python test_uc_connector.py
if __name__ == "__main__":
    unittest.main()
Loading