diff --git a/ucm/integration/vllm/uc_connector.py b/ucm/integration/vllm/uc_connector.py index 87c33ec8..882a1e77 100644 --- a/ucm/integration/vllm/uc_connector.py +++ b/ucm/integration/vllm/uc_connector.py @@ -262,7 +262,7 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: self.layerwise_load_tasks.clear() self.current_layer = 0 - need_wait_tasks = [] + need_load_tasks: dict[str, Task] = {} for request in metadata.requests: if not request.load_blocks: continue @@ -295,10 +295,11 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: if is_load_async: self._need_load_reqs[request.request_id] = task_id else: - need_wait_tasks.append(task_id) - for task_id in need_wait_tasks: + need_load_tasks[request.request_id] = task_id + for req_id, task_id in need_load_tasks.items(): if self.connector.wait(task_id) != 0: - self._load_failed_reqs.add(request.request_id) + self._load_failed_reqs.add(req_id) + logger.error(f"Failed to load blocks for req {req_id}") def wait_for_layer_load(self, layer_name: str) -> None: """