diff --git a/ucm/integration/vllm/ucm_connector.py b/ucm/integration/vllm/ucm_connector.py index 10c6785e..2c8038b6 100644 --- a/ucm/integration/vllm/ucm_connector.py +++ b/ucm/integration/vllm/ucm_connector.py @@ -381,6 +381,8 @@ def _generate_task( dst_tensor_addr.extend(addrs) ucm_offsets.extend(offsets) ucm_total_block_ids = ucm_block_ids * len(self.kv_caches) + if not self.is_mla: + ucm_total_block_ids *= 2 assert len(ucm_total_block_ids) == len(ucm_offsets) == len(dst_tensor_addr) return func(ucm_total_block_ids, ucm_offsets, dst_tensor_addr) @@ -532,6 +534,8 @@ def wait_for_save(self) -> None: break end += 1 + if end == len(ucm_block_ids): + continue ucm_block_ids = ucm_block_ids[:end] vllm_block_ids = vllm_block_ids[:end] request_to_task[request_id] = self._generate_task(