diff --git a/ucm/integration/vllm/ucm_connector.py b/ucm/integration/vllm/ucm_connector.py
index 5a667c91..3c1ade22 100644
--- a/ucm/integration/vllm/ucm_connector.py
+++ b/ucm/integration/vllm/ucm_connector.py
@@ -10,7 +10,7 @@
     KVConnectorMetadata,
     KVConnectorRole,
 )
-from vllm.distributed.parallel_state import get_world_group
+from vllm.distributed.parallel_state import get_tp_group, get_world_group
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.request import Request
 
@@ -108,6 +108,12 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
         self.block_size = self._vllm_config.cache_config.block_size
         self.is_mla = self._vllm_config.model_config.is_deepseek_mla
+        self.load_only_first_rank = self.is_mla
+        if self.is_mla:
+            if role == KVConnectorRole.WORKER:
+                self.group_coordinator = get_tp_group()
+                self.broadcast_fn = self.group_coordinator.broadcast
+                self.broadcast_stream = torch.cuda.Stream()
 
         self.store: UcmKVStoreBase
         self.request_hasher = RequestHasher()
 
@@ -337,6 +343,83 @@ def _generate_task(
         assert len(ucm_total_block_ids) == len(ucm_offsets) == len(dst_tensor_addr)
         return func(ucm_total_block_ids, ucm_offsets, dst_tensor_addr)
 
+    def _generate_load_task_for_broadcast(
+        self,
+        vllm_block_ids,
+        ucm_block_ids,
+        can_load: bool,
+    ) -> tuple[Optional[Task], dict[str, list[torch.Tensor]], int]:
+        """
+        Collect per-layer destination tensors for a load. The store load is only
+        issued when can_load is True (rank 0); other ranks receive via broadcast.
+        """
+        layer_to_tensors = {}
+        total_block_num = len(ucm_block_ids)
+        dst_tensor_addr, ucm_offsets = [], []
+        for layer_name, one_layer_kv_cache in self.kv_caches.items():
+            addrs, offsets = self._get_tensor_and_offset(
+                vllm_block_ids, one_layer_kv_cache, layer_name
+            )
+            layer_to_tensors[layer_name] = addrs[:total_block_num]
+            dst_tensor_addr.extend(addrs)
+            ucm_offsets.extend(offsets)
+        ucm_total_block_ids = ucm_block_ids * len(self.kv_caches)
+
+        task = None
+        if can_load:
+            assert len(ucm_total_block_ids) == len(ucm_offsets) == len(dst_tensor_addr)
+            task = self.store.load(ucm_total_block_ids, ucm_offsets, dst_tensor_addr)
+        return task, layer_to_tensors, total_block_num
+
+    def _broadcast_or_receive_blocks(
+        self, layer_to_tensors: dict[str, list[torch.Tensor]], total_block_num
+    ):
+        """Rank 0 broadcasts each layer's loaded blocks; other ranks receive them."""
+        receive_dict = {}
+        for layer_name, kv_layer in self.kv_caches.items():
+            k_tensors = layer_to_tensors[layer_name][:total_block_num]
+            if self.rank == 0:
+                tensor_to_broadcast = torch.stack(k_tensors, dim=0)
+                self.broadcast_fn(tensor_to_broadcast, 0)
+            else:
+                shape = (len(k_tensors),) + k_tensors[0].shape
+                dtype = k_tensors[0].dtype
+                rec_tensor = torch.empty(shape, dtype=dtype, device=f"cuda:{self.rank}")
+                self.broadcast_fn(rec_tensor, 0)
+                receive_dict[layer_name] = rec_tensor
+        return receive_dict
+
+    def _wait_for_broadcast(
+        self,
+        req_id: str,
+        task: Optional[Task],
+        layer_to_tensors: dict[str, list[torch.Tensor]],
+        total_block_num: int,
+    ):
+        """Wait for rank 0's load, then broadcast/copy the blocks to all ranks."""
+        if self.rank == 0:
+            if self.store.wait(task) != 0:
+                logger.error(f"request {req_id} load kv cache failed.")
+                # NOTE(review): returning here skips the broadcast below while
+                # ranks > 0 still enter it and block — confirm failure handling.
+                return
+            logger.debug(
+                f"request {req_id} load {total_block_num} blocks on rank {self.rank}"
+            )
+        with torch.cuda.stream(self.broadcast_stream):
+            receive_dict = self._broadcast_or_receive_blocks(
+                layer_to_tensors, total_block_num
+            )
+        self.broadcast_stream.synchronize()
+        if self.rank > 0 and receive_dict:
+            for layer_name, kv_layer in self.kv_caches.items():
+                received_tensor = receive_dict[layer_name]
+                for i in range(total_block_num):
+                    layer_to_tensors[layer_name][i].copy_(received_tensor[i])
+            logger.debug(
+                f"request {req_id} receive broadcast {total_block_num} blocks on rank {self.rank}"
+            )
+
     def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
         metadata = self._get_connector_metadata()
 
@@ -344,7 +427,8 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
 
         self._init_kv_caches_from_forward_context(forward_context)
 
-        request_to_task: dict[str, Task] = {}
+        request_to_task: dict[str, Optional[Task]] = {}
+        req_to_layer = {}
         for request_id, request in metadata.request_meta.items():
             hbm_hit_block_num = request.hbm_hit_block_num
             total_hit_block_num = request.total_hit_block_num
@@ -356,13 +440,29 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
                 hbm_hit_block_num:total_hit_block_num
             ]
             ucm_block_ids = request.ucm_block_ids[hbm_hit_block_num:total_hit_block_num]
-            request_to_task[request_id] = self._generate_task(
-                vllm_block_ids, ucm_block_ids, self.store.load
-            )
+            if self.load_only_first_rank:
+                can_load = self.rank == 0
+                task, layer_to_tensors, total_block_num = (
+                    self._generate_load_task_for_broadcast(
+                        vllm_block_ids, ucm_block_ids, can_load
+                    )
+                )
+                req_to_layer[request_id] = (layer_to_tensors, total_block_num)
+            else:
+                task = self._generate_task(
+                    vllm_block_ids, ucm_block_ids, self.store.load
+                )
+            request_to_task[request_id] = task
 
         for req_id, task in request_to_task.items():
-            if self.store.wait(task) != 0:
-                logger.error(f"request {req_id} load kv cache failed.")
+            if self.load_only_first_rank:
+                layer_to_tensors, total_block_num = req_to_layer[req_id]
+                self._wait_for_broadcast(
+                    req_id, task, layer_to_tensors, total_block_num
+                )
+            else:
+                if self.store.wait(task) != 0:
+                    logger.error(f"request {req_id} load kv cache failed.")
 
     def wait_for_layer_load(self, layer_name: str) -> None:
         pass
@@ -378,7 +478,7 @@ def save_kv_layer(
 
 
     def wait_for_save(self) -> None:
-        if self.is_mla and self.rank != 0:
+        if self.load_only_first_rank and self.rank != 0:
             return
 
         metadata = self._get_connector_metadata()