From 2d140fd4e97c6e6ac31c94cf367c2a321b143619 Mon Sep 17 00:00:00 2001 From: qyh111 Date: Fri, 28 Nov 2025 22:14:27 -0800 Subject: [PATCH 1/4] fix accuracy problem when chunked prefill --- ucm/integration/vllm/ucm_connector.py | 38 ++++++++++++++++++--------- 1 file changed, 26 insertions(+), 12 deletions(-) diff --git a/ucm/integration/vllm/ucm_connector.py b/ucm/integration/vllm/ucm_connector.py index f4b1f4d3..4d5dadb9 100644 --- a/ucm/integration/vllm/ucm_connector.py +++ b/ucm/integration/vllm/ucm_connector.py @@ -16,7 +16,7 @@ from vllm.distributed.parallel_state import get_tp_group, get_world_group from vllm.platforms import current_platform from vllm.v1.core.sched.output import SchedulerOutput -from vllm.v1.request import Request +from vllm.v1.request import Request, RequestStatus from ucm.logger import init_logger from ucm.shared.metrics import ucmmonitor @@ -39,6 +39,9 @@ class RequestMeta: hbm_hit_block_num: int = 0 # local_computed_block + external_computed_block total_hit_block_num: int = 0 + num_token_ids: int = 0 + vllm_block_ids: list[int] = field(default_factory=list) + token_processed: int = 0 @dataclass @@ -207,6 +210,8 @@ def get_num_new_matched_tokens( request: "Request", num_computed_tokens: int, ) -> tuple[int, bool]: + if request.status == RequestStatus.PREEMPTED: + self.requests_meta.pop(request.request_id, None) assert num_computed_tokens % self.block_size == 0 hbm_hit_block_num = num_computed_tokens // self.block_size @@ -249,6 +254,7 @@ def get_num_new_matched_tokens( ucm_block_ids=ucm_block_ids, hbm_hit_block_num=hbm_hit_block_num, total_hit_block_num=total_hit_block_num, + num_token_ids=len(request.all_token_ids), ) return external_hit_tokens, False @@ -277,22 +283,29 @@ def _generate_dispatch_meta( | scheduled_block_num | """ - new_blocks_num = new_tokens // self.block_size hbm_hit_block_num = req_meta.hbm_hit_block_num total_hit_block_num = req_meta.total_hit_block_num - scheduled_block_num = total_hit_block_num + new_blocks_num 
ucm_block_ids = req_meta.ucm_block_ids + req_meta.vllm_block_ids.extend(vllm_block_ids) - dump_ucm_block_ids = ucm_block_ids[total_hit_block_num:scheduled_block_num] if need_load: + new_blocks_num = new_tokens // self.block_size + scheduled_block_num = total_hit_block_num + new_blocks_num + dump_ucm_block_ids = ucm_block_ids[total_hit_block_num:scheduled_block_num] dump_vllm_block_ids = vllm_block_ids[ total_hit_block_num:scheduled_block_num ] + req_meta.token_processed = ( + new_tokens + self.block_size * total_hit_block_num + ) else: - dump_vllm_block_ids = vllm_block_ids - - # after this round, req_meta will be updated - req_meta.total_hit_block_num = scheduled_block_num + if req_meta.token_processed >= req_meta.num_token_ids: + return RequestDispatchMeta(([], []), ([], [])) + start_idx = req_meta.token_processed // self.block_size + end_idx = (req_meta.token_processed + new_tokens) // self.block_size + dump_ucm_block_ids = ucm_block_ids[start_idx:end_idx] + dump_vllm_block_ids = req_meta.vllm_block_ids[start_idx:end_idx] + req_meta.token_processed += new_tokens load_ucm_block_ids, load_vllm_block_ids = [], [] if need_load: @@ -327,15 +340,16 @@ def build_connector_meta( if not isinstance(scheduled_cached_reqs, list): # >= 0.9.2 for i, request_id in enumerate(scheduled_cached_reqs.req_ids): - if scheduler_output.num_scheduled_tokens[request_id] == 1: - # decode stage - continue req_meta = self.requests_meta.get(request_id) if req_meta: + if scheduled_cached_reqs.new_block_ids[i] != None: + new_block_ids = scheduled_cached_reqs.new_block_ids[i][0] + else: + new_block_ids = [] requests_dispatch_meta[request_id] = self._generate_dispatch_meta( req_meta, scheduler_output.num_scheduled_tokens[request_id], - scheduled_cached_reqs.new_block_ids[i][0], + new_block_ids, scheduled_cached_reqs.resumed_from_preemption[i], ) else: From 8c1c3874cadfbd3a8a6de4466e392ae86128b813 Mon Sep 17 00:00:00 2001 From: qyh111 Date: Fri, 28 Nov 2025 22:37:01 -0800 Subject: [PATCH 2/4] 
remove unnecessary code for preempted request --- ucm/integration/vllm/ucm_connector.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/ucm/integration/vllm/ucm_connector.py b/ucm/integration/vllm/ucm_connector.py index 4d5dadb9..d7d664a2 100644 --- a/ucm/integration/vllm/ucm_connector.py +++ b/ucm/integration/vllm/ucm_connector.py @@ -16,7 +16,7 @@ from vllm.distributed.parallel_state import get_tp_group, get_world_group from vllm.platforms import current_platform from vllm.v1.core.sched.output import SchedulerOutput -from vllm.v1.request import Request, RequestStatus +from vllm.v1.request import Request from ucm.logger import init_logger from ucm.shared.metrics import ucmmonitor @@ -210,9 +210,6 @@ def get_num_new_matched_tokens( request: "Request", num_computed_tokens: int, ) -> tuple[int, bool]: - if request.status == RequestStatus.PREEMPTED: - self.requests_meta.pop(request.request_id, None) - assert num_computed_tokens % self.block_size == 0 hbm_hit_block_num = num_computed_tokens // self.block_size From 30a4246d0d15e20cefc7276356b8133db17ed780 Mon Sep 17 00:00:00 2001 From: qyh111 Date: Fri, 28 Nov 2025 23:15:44 -0800 Subject: [PATCH 3/4] modify dump --- ucm/integration/vllm/ucm_connector.py | 27 +++++++++------------------ 1 file changed, 9 insertions(+), 18 deletions(-) diff --git a/ucm/integration/vllm/ucm_connector.py b/ucm/integration/vllm/ucm_connector.py index d7d664a2..dd9540d4 100644 --- a/ucm/integration/vllm/ucm_connector.py +++ b/ucm/integration/vllm/ucm_connector.py @@ -252,6 +252,7 @@ def get_num_new_matched_tokens( hbm_hit_block_num=hbm_hit_block_num, total_hit_block_num=total_hit_block_num, num_token_ids=len(request.all_token_ids), + token_processed = self.block_size * total_hit_block_num ) return external_hit_tokens, False @@ -285,24 +286,11 @@ def _generate_dispatch_meta( ucm_block_ids = req_meta.ucm_block_ids req_meta.vllm_block_ids.extend(vllm_block_ids) - if need_load: - new_blocks_num = new_tokens // self.block_size 
- scheduled_block_num = total_hit_block_num + new_blocks_num - dump_ucm_block_ids = ucm_block_ids[total_hit_block_num:scheduled_block_num] - dump_vllm_block_ids = vllm_block_ids[ - total_hit_block_num:scheduled_block_num - ] - req_meta.token_processed = ( - new_tokens + self.block_size * total_hit_block_num - ) - else: - if req_meta.token_processed >= req_meta.num_token_ids: - return RequestDispatchMeta(([], []), ([], [])) - start_idx = req_meta.token_processed // self.block_size - end_idx = (req_meta.token_processed + new_tokens) // self.block_size - dump_ucm_block_ids = ucm_block_ids[start_idx:end_idx] - dump_vllm_block_ids = req_meta.vllm_block_ids[start_idx:end_idx] - req_meta.token_processed += new_tokens + start_idx = req_meta.token_processed // self.block_size + end_idx = (req_meta.token_processed + new_tokens) // self.block_size + dump_ucm_block_ids = ucm_block_ids[start_idx:end_idx] + dump_vllm_block_ids = req_meta.vllm_block_ids[start_idx:end_idx] + req_meta.token_processed += new_tokens load_ucm_block_ids, load_vllm_block_ids = [], [] if need_load: @@ -337,6 +325,9 @@ def build_connector_meta( if not isinstance(scheduled_cached_reqs, list): # >= 0.9.2 for i, request_id in enumerate(scheduled_cached_reqs.req_ids): + if scheduler_output.num_scheduled_tokens[request_id] == 1: + # decode stage + continue req_meta = self.requests_meta.get(request_id) if req_meta: if scheduled_cached_reqs.new_block_ids[i] != None: From 5d90ace53d54112e4e886ed55cc34eabee7df880 Mon Sep 17 00:00:00 2001 From: qyh111 Date: Fri, 28 Nov 2025 23:30:08 -0800 Subject: [PATCH 4/4] fix comment --- ucm/integration/vllm/ucm_connector.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ucm/integration/vllm/ucm_connector.py b/ucm/integration/vllm/ucm_connector.py index dd9540d4..b3ec3544 100644 --- a/ucm/integration/vllm/ucm_connector.py +++ b/ucm/integration/vllm/ucm_connector.py @@ -244,7 +244,8 @@ def get_num_new_matched_tokens( # When all the tokens are cached in 
ssd or hbm, # we need to recompute the last token. This if condition will be removed # once vLLM scheduler provides a better solution in the future. - if total_hit_block_num * self.block_size == request.num_tokens: + num_total_hit_tokens = total_hit_block_num * self.block_size + if num_total_hit_tokens == request.num_tokens: external_hit_tokens -= 1 self.requests_meta[request.request_id] = RequestMeta( @@ -252,7 +253,7 @@ def get_num_new_matched_tokens( hbm_hit_block_num=hbm_hit_block_num, total_hit_block_num=total_hit_block_num, num_token_ids=len(request.all_token_ids), - token_processed = self.block_size * total_hit_block_num + token_processed=num_total_hit_tokens, ) return external_hit_tokens, False @@ -330,10 +331,9 @@ def build_connector_meta( continue req_meta = self.requests_meta.get(request_id) if req_meta: + new_block_ids = [] if scheduled_cached_reqs.new_block_ids[i] != None: new_block_ids = scheduled_cached_reqs.new_block_ids[i][0] - else: - new_block_ids = [] requests_dispatch_meta[request_id] = self._generate_dispatch_meta( req_meta, scheduler_output.num_scheduled_tokens[request_id],