Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 104 additions & 8 deletions ucm/integration/vllm/ucm_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
KVConnectorMetadata,
KVConnectorRole,
)
from vllm.distributed.parallel_state import get_world_group
from vllm.distributed.parallel_state import get_tp_group, get_world_group
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.request import Request

Expand Down Expand Up @@ -108,6 +108,12 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
self.block_size = self._vllm_config.cache_config.block_size
self.is_mla = self._vllm_config.model_config.is_deepseek_mla

self.load_only_first_rank = self.is_mla
if self.is_mla:
if role == KVConnectorRole.WORKER:
self.group_coordinator = get_tp_group()
self.broadcast_fn = self.group_coordinator.broadcast
self.broadcast_stream = torch.cuda.Stream()
self.store: UcmKVStoreBase

self.request_hasher = RequestHasher()
Expand Down Expand Up @@ -337,14 +343,88 @@ def _generate_task(
assert len(ucm_total_block_ids) == len(ucm_offsets) == len(dst_tensor_addr)
return func(ucm_total_block_ids, ucm_offsets, dst_tensor_addr)

def _generate_load_task_for_broadcast(
self,
vllm_block_ids,
ucm_block_ids,
can_load: bool,
) -> tuple[Task, dict[str, torch.Tensor], int]:
"""
Load or Dump func is only called in rank 0 in MLA;
In rank != 0, worker will receive broadcast tensors from rank 0.
"""
layer_to_tensors = {}
total_block_num = len(ucm_block_ids)
dst_tensor_addr, ucm_offsets = [], []
for layer_name, one_layer_kv_cache in self.kv_caches.items():
addrs, offsets = self._get_tensor_and_offset(
vllm_block_ids, one_layer_kv_cache, layer_name
)
layer_to_tensors[layer_name] = addrs[:total_block_num]
dst_tensor_addr.extend(addrs)
ucm_offsets.extend(offsets)
ucm_total_block_ids = ucm_block_ids * len(self.kv_caches)

task = None
if can_load:
assert len(ucm_total_block_ids) == len(ucm_offsets) == len(dst_tensor_addr)
task = self.store.load(ucm_total_block_ids, ucm_offsets, dst_tensor_addr)
return task, layer_to_tensors, total_block_num

def _broadcast_or_receive_blocks(
self, layer_to_tensors: dict[str : torch.Tensor], total_block_num
):
receive_dict = {}
for layer_name, kv_layer in self.kv_caches.items():
k_tensors = layer_to_tensors[layer_name][:total_block_num]
if self.rank == 0:
tensor_to_broadcast = torch.stack(k_tensors, dim=0)
self.broadcast_fn(tensor_to_broadcast, 0)
else:
shape = (len(k_tensors),) + k_tensors[0].shape
dtype = k_tensors[0].dtype
rec_tensor = torch.empty(shape, dtype=dtype, device=f"cuda:{self.rank}")
self.broadcast_fn(rec_tensor, 0)
receive_dict[layer_name] = rec_tensor
return receive_dict

def _wait_for_broadcast(
    self,
    req_id: str,
    task: Task,
    layer_to_tensors: dict[str, list[torch.Tensor]],
    total_block_num: int,
):
    """Complete an MLA load: rank 0 waits for its store load task, then
    all TP ranks exchange the loaded blocks via broadcast.

    Args:
        req_id: request id, used for logging only.
        task: in-flight store load task (presumably None on ranks != 0
            — only rank 0's task is waited on here).
        layer_to_tensors: per-layer destination tensors of the loaded
            blocks.
        total_block_num: number of blocks loaded for this request.
    """
    if self.rank == 0:
        # Only rank 0 issued the store load; ensure it finished before
        # broadcasting its contents.
        if self.store.wait(task) != 0:
            logger.error(f"request {req_id} load kv cache failed.")
            # NOTE(review): this early return skips the broadcast while
            # ranks != 0 still post matching receives below — looks like
            # a potential hang on load failure; confirm how the failure
            # is handled across ranks.
            return
        logger.debug(
            f"request {req_id} load {total_block_num} blocks on rank {self.rank}"
        )
    # Run the collective on the dedicated broadcast stream, then block
    # until the transfers complete before touching the received data.
    with torch.cuda.stream(self.broadcast_stream):
        receive_dict = self._broadcast_or_receive_blocks(
            layer_to_tensors, total_block_num
        )
    self.broadcast_stream.synchronize()
    if self.rank > 0 and receive_dict:
        # Copy each received (stacked) layer tensor back block-by-block
        # into this rank's KV-cache block tensors.
        for layer_name, kv_layer in self.kv_caches.items():
            received_tensor = receive_dict[layer_name]
            for i in range(total_block_num):
                layer_to_tensors[layer_name][i].copy_(received_tensor[i])
        logger.debug(
            f"request {req_id} receive broadcast {total_block_num} blocks on rank {self.rank}"
        )

def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:

metadata = self._get_connector_metadata()
assert isinstance(metadata, UCMConnectorMetadata)

self._init_kv_caches_from_forward_context(forward_context)

request_to_task: dict[str, Task] = {}
request_to_task: dict[str, Optional[Task]] = {}
req_to_layer = {}
for request_id, request in metadata.request_meta.items():
hbm_hit_block_num = request.hbm_hit_block_num
total_hit_block_num = request.total_hit_block_num
Expand All @@ -356,13 +436,29 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
hbm_hit_block_num:total_hit_block_num
]
ucm_block_ids = request.ucm_block_ids[hbm_hit_block_num:total_hit_block_num]
request_to_task[request_id] = self._generate_task(
vllm_block_ids, ucm_block_ids, self.store.load
)
if self.load_only_first_rank:
can_load = self.rank == 0
task, layer_to_tensors, total_block_num = (
self._generate_load_task_for_broadcast(
vllm_block_ids, ucm_block_ids, can_load
)
)
req_to_layer[request_id] = (layer_to_tensors, total_block_num)
else:
task = self._generate_task(
vllm_block_ids, ucm_block_ids, self.store.load
)
request_to_task[request_id] = task

for req_id, task in request_to_task.items():
if self.store.wait(task) != 0:
logger.error(f"request {req_id} load kv cache failed.")
if self.load_only_first_rank:
layer_to_tensors, total_block_num = req_to_layer[req_id]
self._wait_for_broadcast(
req_id, task, layer_to_tensors, total_block_num
)
else:
if self.store.wait(task) != 0:
logger.error(f"request {req_id} load kv cache failed.")

def wait_for_layer_load(self, layer_name: str) -> None:
    """Intentional no-op: load completion is awaited elsewhere, so
    there is nothing to do at per-layer granularity."""
    return None
Expand All @@ -378,7 +474,7 @@ def save_kv_layer(

def wait_for_save(self) -> None:

if self.is_mla and self.rank != 0:
if self.load_only_first_rank and self.rank != 0:
return

metadata = self._get_connector_metadata()
Expand Down