2 changes: 1 addition & 1 deletion setup.py
@@ -128,7 +128,7 @@ def _copy_so_files(self, ext: CMakeExtension):
build_install_dir = STORE_INSTALL_DIR
else:
install_dir = GSA_INSTALL_DIR
build_install_dir = "ucm_sparse"
build_install_dir = "ucm/ucm_sparse"

for so_file in so_files:
src_path = os.path.join(so_search_dir, so_file)
7 changes: 0 additions & 7 deletions test/test_uc_connector.py
@@ -83,17 +83,12 @@ def setUp(self):
self.total_blocks_num = 40
self.total_tp_size = 2
self.kv_caches = {}
-        self.k_data_offsets = {}
for i in range(self.num_layers):
layer_name = f"model.layers.{i}.self_attn.attn"
kv_tensor = torch.rand(
(2, self.total_blocks_num, self.block_size, 4, 8), dtype=torch.bfloat16
)
self.kv_caches[layer_name] = kv_tensor
-        for layer_id in range(self.num_layers):
-            self.k_data_offsets[layer_id] = {}
-            for i in range(self.total_tp_size):
-                self.k_data_offsets[layer_id][i] = 0

def init_uc(
self, mock_connector, metadata=Mock(), use_layerwise=True
@@ -116,8 +111,6 @@ def init_uc(
ucconnector._need_load_reqs: dict[str, Union[list[int], list[Task]]] = {}
ucconnector._load_failed_reqs: set[str] = set()
ucconnector._load_req_to_blocks: dict[str, set[int]] = {}
-        ucconnector.k_data_offsets = self.k_data_offsets
-        ucconnector.min_block_size = 0
return ucconnector

def test_get_num_new_matched_tokens_hit_all_on_storage(self):
65 changes: 31 additions & 34 deletions ucm/integration/vllm/uc_connector.py
@@ -113,8 +113,6 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
vllm_config.parallel_config
)
self.head_size = vllm_config.model_config.get_head_size()
-        if role == KVConnectorRole.WORKER:
-            self._initialize_dataoffset(vllm_config)
if (
self._vllm_config.kv_transfer_config is not None
and "ucm_connector_name"
@@ -176,35 +174,37 @@ def _init_kv_caches_from_forward_context(self, forward_context: "ForwardContext"
forward_context.virtual_engine
]

-    def _initialize_dataoffset(self, vllm_config: "VllmConfig"):
-        num_kv_heads = vllm_config.model_config.get_num_kv_heads(
-            vllm_config.parallel_config
-        )
-        head_size = vllm_config.model_config.get_head_size()
-        self.min_block_size = (
-            self.block_size * num_kv_heads * head_size * self.element_size
-        )
-        layer_size = (
-            self.min_block_size * 2 * self.total_tp_size
-            if not self.is_mla
-            else self.min_block_size
-        )
-        # layer_id -> rank -> k_offset
-        self.k_data_offsets: dict[int, dict[int, int]] = {}
-
-        pp_size = vllm_config.parallel_config.pipeline_parallel_size
-        for layer_id in range(self.num_layers * pp_size):
-            self.k_data_offsets[layer_id] = {}
-            for rank in range(self.total_tp_size):
-                if self.is_mla:
-                    self.k_data_offsets[layer_id][0] = layer_size * layer_id
-                    break
-                else:
-                    offset = (
-                        layer_size * layer_id
-                        + (layer_size // self.total_tp_size) * rank
-                    )
-                    self.k_data_offsets[layer_id][rank] = offset
+    def DataOffset(self, kv_layer, rank, layer_id, is_v):
+        # Non-MLA scene: one layer shape is (2, num_blocks, block_size, num_kv_heads, head_size)
+        # MLA scene: one layer shape is (num_blocks, block_size, head_size)
+        # Element size
+        elem_size = kv_layer[0].element_size()
+        logger.debug(
+            f"total_tp_size = {self.total_tp_size},\n" f"element size = {elem_size}."
+        )
+        # One block size
+        k_min_data_block_size = (
+            kv_layer[0][0].numel() if not self.is_mla else kv_layer[0].numel()
+        ) * elem_size
+        v_min_data_block_size = (
+            kv_layer[1][0].numel() if not self.is_mla else 0
+        ) * elem_size
+        # When tp > 1 layer_size = (k_min_data_block_size + v_min_data_block_size) * tp_size
+        layer_size = (
+            k_min_data_block_size + v_min_data_block_size
+        ) * self.total_tp_size
+        if is_v:
+            # Offset of v = Offset of k + k_min_data_block_size
+            return int(
+                self.DataOffset(kv_layer, rank, layer_id, False) + k_min_data_block_size
+            )
+        if self.is_mla:
+            return int(layer_size * layer_id)
+        else:
+            # Offset of k = layer_size * layer_id + layer_size / tp_size * current rank
+            return int(
+                layer_size * layer_id + layer_size / self.total_tp_size * self.rank
+            )

def get_tensor_and_offset_layerwise(
self, vllm_block_ids: List[int], kv_layer: torch.Tensor, layer_name: str
@@ -216,17 +216,14 @@ def get_tensor_and_offset_layerwise(
layer_id = self._extract_layer_index(layer_name)

for blk_id in vllm_block_ids:
+            k_data_offset = self.DataOffset(kv_layer, self.rank, layer_id, False)
if self.is_mla:
-                k_data_offset = self.k_data_offsets[layer_id][0]
k_tensors.append(kv_layer[blk_id])
else:
-                k_data_offset = self.k_data_offsets[layer_id][self.rank]
k_tensors.append(kv_layer[0][blk_id])
k_offsets.append(k_data_offset)
if not self.is_mla:
-                v_data_offset = (
-                    self.k_data_offsets[layer_id][self.rank] + self.min_block_size
-                )
+                v_data_offset = self.DataOffset(kv_layer, self.rank, layer_id, True)
v_tensors.append(kv_layer[1][blk_id])
v_offsets.append(v_data_offset)
return k_tensors + v_tensors, k_offsets + v_offsets
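The replacement of the precomputed k_data_offsets table with DataOffset keeps the same byte layout: each layer occupies (k_block + v_block) * tp_size bytes, a rank's K shard starts at layer_size * layer_id + layer_size / tp_size * rank, and that rank's V shard follows immediately after its K shard. Below is a minimal standalone sketch of that arithmetic for a hypothetical non-MLA cache; the shapes, TP size, and rank are illustrative values, not taken from this PR.

```python
import torch

# Hypothetical non-MLA KV cache for one layer on one rank:
# shape (2, num_blocks, block_size, num_kv_heads, head_size)
num_blocks, block_size, num_kv_heads, head_size = 4, 16, 8, 128
total_tp_size, rank = 2, 1
kv_layer = torch.zeros(
    (2, num_blocks, block_size, num_kv_heads, head_size), dtype=torch.bfloat16
)

elem_size = kv_layer[0].element_size()            # 2 bytes for bfloat16
k_block = kv_layer[0][0].numel() * elem_size      # bytes of one K block on this rank
v_block = kv_layer[1][0].numel() * elem_size      # bytes of one V block on this rank
layer_size = (k_block + v_block) * total_tp_size  # bytes of one full layer across TP ranks

def k_offset(layer_id: int) -> int:
    # K blocks of this rank start after all earlier layers plus earlier ranks' shards.
    return layer_size * layer_id + (layer_size // total_tp_size) * rank

def v_offset(layer_id: int) -> int:
    # V blocks of a rank sit right after its K blocks within the same layer shard.
    return k_offset(layer_id) + k_block

print(k_offset(0), v_offset(0), k_offset(1))
```

For MLA, v_min_data_block_size is zero and the offset reduces to layer_size * layer_id, with no per-rank term.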
6 changes: 2 additions & 4 deletions ucm/ucm_sparse/gsa.py
@@ -456,9 +456,7 @@ def __init__(self, vllm_config: VllmConfig, role: UcmSparseRole):
self.prefetch_engine = GSAPrefetchBase(
vllm_config, 16, True, True, False, 1
)
-        self.topk_kpre_manger = TopKAndKpreManger(
-            vllm_config.scheduler_config.max_num_seqs
-        )
+        self.topk_kpre_manger = TopKAndKpreManger(MAX_BS)
self.k_cache = {}
self.v_cache = {}
self.tasks_dump = {}
@@ -504,7 +502,7 @@ def init_topk_cal(
self.gsa_q_cache = torch.zeros(
(
self.layer_num,
-                vllm_config.scheduler_config.max_num_seqs,
+                MAX_BS,
att_num_heads,
head_size,
),
11 changes: 8 additions & 3 deletions ucm/ucm_sparse/prefetch_engine.py
@@ -169,9 +169,14 @@ def _topk_tmp_deal(self, gsa_metadata, topk_buf_tmp):
for index, topk_info in enumerate(self.topk_bs):
if topk_info[1]:
if topk_info[0] in gsa_metadata.gsa_stats:
-                    gsa_metadata.gsa_stats[topk_info[0]].topk_buf_tmp = (
-                        self.topk_buf_tmp[:, index, : topk_info[2]].clone()
-                    )
+                    if not self.is_cpu_topk:
+                        gsa_metadata.gsa_stats[topk_info[0]].topk_buf_tmp = (
+                            self.topk_buf_tmp[:, index, : topk_info[2]].cpu()
+                        )
+                    else:
+                        gsa_metadata.gsa_stats[topk_info[0]].topk_buf_tmp = (
+                            self.topk_buf_tmp[:, index, : topk_info[2]].clone()
+                        )
self.topk_bs = []
for index, req_id in enumerate(self.req_ids_bs):
one_block_len = len(gsa_metadata.gsa_stats[req_id].blocks)
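The new branch in _topk_tmp_deal stashes the per-request top-k slice as a host copy when is_cpu_topk is False, presumably because the buffer lives on the accelerator in that case, and keeps the previous on-device clone() otherwise. A self-contained sketch of the same pattern, with made-up buffer shapes and flag values:

```python
import torch

# Hypothetical top-k buffer: (num_layers, batch_slots, max_topk); sizes are illustrative.
device = "cuda" if torch.cuda.is_available() else "cpu"
topk_buf_tmp = torch.zeros((4, 8, 32), dtype=torch.int32, device=device)
index, valid_len = 2, 20
is_cpu_topk = device == "cpu"

row = topk_buf_tmp[:, index, :valid_len]
if not is_cpu_topk:
    # Device-resident top-k: .cpu() materializes an independent host copy,
    # so later reuse of topk_buf_tmp cannot clobber the stashed result.
    stashed = row.cpu()
else:
    # Already on CPU: .clone() is enough for an independent copy on the same device.
    stashed = row.clone()
print(stashed.shape, stashed.device)
```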