diff --git a/ucm/integration/vllm/uc_connector.py b/ucm/integration/vllm/uc_connector.py index 9a08b8c0..6c92b70d 100644 --- a/ucm/integration/vllm/uc_connector.py +++ b/ucm/integration/vllm/uc_connector.py @@ -289,6 +289,7 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: self._need_load_reqs[request.request_id].append(k_task) if not self.is_mla: self._need_load_reqs[request.request_id].append(v_task) + self.layerwise_load_tasks.pop(request.request_id) continue if ( @@ -570,18 +571,17 @@ def md5(input) -> int: if ( request.kv_transfer_params and request.kv_transfer_params["load_async"] == False - ): + ) or num_lookup_hits == 0: return 0, False request.kv_transfer_params = request.kv_transfer_params or {} request.kv_transfer_params["load_async"] = False - if num_lookup_hits > 0: - self.request_block_infos[request.request_id] = RequestBlockInfo( - block_hashes=block_hashes, - block_operations=block_operations, - start_position=start_position, - ) - self._need_load_reqs[request.request_id] = [] - return num_lookup_hits * self.block_size, True + self.request_block_infos[request.request_id] = RequestBlockInfo( + block_hashes=block_hashes, + block_operations=block_operations, + start_position=start_position, + ) + self._need_load_reqs[request.request_id] = [] + return num_lookup_hits * self.block_size, True # Create blocks for the remaining (unmatched) blocks if num_lookup_hits < len(remain_hashes):