From 2a3aea4a02a56ab7fa90e80fc341f190d3eee586 Mon Sep 17 00:00:00 2001 From: hek14 <1023129548@qq.com> Date: Fri, 26 Sep 2025 09:48:18 +0800 Subject: [PATCH] [fix] ktc config --- examples/offline_inference.py | 1 + .../vllm/patch/0.9.2/vllm-adapt.patch | 250 +++++++++--------- 2 files changed, 119 insertions(+), 132 deletions(-) diff --git a/examples/offline_inference.py b/examples/offline_inference.py index 426b7b23..13f320f9 100644 --- a/examples/offline_inference.py +++ b/examples/offline_inference.py @@ -49,6 +49,7 @@ def build_llm_with_uc(module_path: str, name: str, model: str): llm_args = EngineArgs( model=model, + kv_transfer_config=ktc, max_model_len=32768, gpu_memory_utilization=0.8, max_num_batched_tokens=30000, diff --git a/ucm/integration/vllm/patch/0.9.2/vllm-adapt.patch b/ucm/integration/vllm/patch/0.9.2/vllm-adapt.patch index 84f5e4c8..2ef832ac 100644 --- a/ucm/integration/vllm/patch/0.9.2/vllm-adapt.patch +++ b/ucm/integration/vllm/patch/0.9.2/vllm-adapt.patch @@ -1,29 +1,29 @@ -From 9124f6f48b958f2535702d8093495097257a2ccc Mon Sep 17 00:00:00 2001 -From: wenxinwang -Date: Thu, 25 Sep 2025 05:03:42 -0700 +From 555ba9e4920445381aecda262b9146342e92eeee Mon Sep 17 00:00:00 2001 +From: hek14 <1023129548@qq.com> +Date: Fri, 26 Sep 2025 09:51:36 +0800 Subject: [PATCH] UCM adaptor --- vllm/attention/layer.py | 45 ++++- .../kv_transfer/kv_connector/utils.py | 113 ++++++++++++ - .../kv_transfer/kv_connector/v1/base.py | 11 +- + .../kv_transfer/kv_connector/v1/base.py | 9 + .../v1/shared_storage_connector.py | 7 +- vllm/v1/core/block_pool.py | 2 +- vllm/v1/core/kv_cache_manager.py | 11 +- vllm/v1/core/sched/output.py | 3 + - vllm/v1/core/sched/scheduler.py | 165 +++++++++++++++++- + vllm/v1/core/sched/scheduler.py | 164 +++++++++++++++++- vllm/v1/core/single_type_kv_cache_manager.py | 3 + vllm/v1/executor/multiproc_executor.py | 30 +++- vllm/v1/outputs.py | 5 + vllm/v1/request.py | 2 +- vllm/v1/worker/block_table.py | 13 ++ vllm/v1/worker/gpu_input_batch.py | 9 + - vllm/v1/worker/gpu_model_runner.py | 122 +++++++++++-- + vllm/v1/worker/gpu_model_runner.py | 120 +++++++++++-- vllm/v1/worker/gpu_worker.py | 25 ++- - 16 files changed, 526 insertions(+), 40 deletions(-) + 16 files changed, 524 insertions(+), 37 deletions(-) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py -index f0ad68b16..89b3da489 100644 +index f0ad68b16..2acde35d8 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -2,7 +2,6 @@ @@ -39,8 +39,8 @@ index f0ad68b16..89b3da489 100644 from vllm.utils import direct_register_custom_op from vllm.v1.attention.backends.utils import validate_kv_sharing_target +from ucm.integration.vllm.ucm_sparse.state import get_ucm_sparse, has_ucm_sparse - - + + class Attention(nn.Module): @@ -409,9 +409,10 @@ def unified_attention( attn_metadata = attn_metadata[layer_name] @@ -53,7 +53,7 @@ index f0ad68b16..89b3da489 100644 + maybe_execute_sparse_attention_finished(query, key, value, output, layer_name, forward_context) maybe_save_kv_layer_to_connector(layer_name, kv_cache) return output - + @@ -449,6 +450,7 @@ def unified_attention_with_output( attn_metadata = attn_metadata[layer_name] self = forward_context.no_compile_layers[layer_name] @@ -69,8 +69,8 @@ index f0ad68b16..89b3da489 100644 - + maybe_execute_sparse_attention_finished(query, key, value, output, layer_name, forward_context) maybe_save_kv_layer_to_connector(layer_name, kv_cache) - - + + @@ -479,3 +481,40 @@ direct_register_custom_op( fake_impl=unified_attention_with_output_fake, 
dispatch_key=current_platform.dispatch_key, @@ -113,7 +113,7 @@ index f0ad68b16..89b3da489 100644 + + ucm_sparse.attention_finished(query, key, value, attn_output, layer_name, forward_context) diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py -index 5cbc8ca31..0fee7e74c 100644 +index 5cbc8ca31..8556a979e 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -3,12 +3,18 @@ @@ -126,15 +126,15 @@ index 5cbc8ca31..0fee7e74c 100644 +from typing import Optional, cast + import torch - + import vllm.envs as envs from vllm import _custom_ops as ops from vllm.config import VllmConfig, get_current_vllm_config from vllm.logger import init_logger +from vllm.v1.outputs import ModelRunnerOutput - + logger = init_logger(__name__) - + @@ -107,3 +113,110 @@ def get_kv_connector_cache_layout(): "layout to HND for better xfer performance.") return "HND" @@ -247,15 +247,13 @@ index 5cbc8ca31..0fee7e74c 100644 + + return result_future diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py -index f80b5eba2..61424b10d 100644 +index f80b5eba2..8891246e6 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py -@@ -200,7 +200,16 @@ class KVConnectorBase_V1(ABC): - call to this method (this call or a prior one). +@@ -201,6 +201,15 @@ class KVConnectorBase_V1(ABC): """ return None, None -- -+ + + def get_block_ids_with_load_errors(self) -> Optional[set[int]]: + """ + Get the set of block IDs that failed to load. @@ -279,10 +277,10 @@ index 3c574d065..223106def 100644 -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import TYPE_CHECKING - + import safetensors @@ -53,10 +53,7 @@ class ReqMeta: - + @dataclass class SharedStorageConnectorMetadata(KVConnectorMetadata): - requests: list[ReqMeta] @@ -290,7 +288,7 @@ index 3c574d065..223106def 100644 - def __init__(self): - self.requests = [] + requests: list[ReqMeta] = field(default_factory=list) - + def add_request( self, diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py @@ -311,12 +309,12 @@ index 6937455e7..c36a25bc5 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -3,7 +3,7 @@ - + from collections import defaultdict from dataclasses import dataclass -from typing import Optional +from typing import Optional, Union - + from vllm.distributed.kv_events import KVCacheEvent from vllm.logger import init_logger @@ -14,6 +14,8 @@ from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock, @@ -325,9 +323,9 @@ index 6937455e7..c36a25bc5 100644 from vllm.v1.request import Request, RequestStatus +from ucm.integration.vllm.ucm_sparse.state import get_ucm_sparse, has_ucm_sparse +from ucm.integration.vllm.ucm_sparse.base import INVALID_SLOT - + logger = init_logger(__name__) - + @@ -193,6 +195,7 @@ class KVCacheManager: num_draft_tokens: int = 0, num_lookahead_tokens: int = 0, @@ -335,7 +333,7 @@ index 6937455e7..c36a25bc5 100644 + num_slots_sparsed: Union[None, int] = None ) -> Optional[KVCacheBlocks]: """Add slots for a request with new tokens to append. 
- + @@ -231,6 +234,12 @@ class KVCacheManager: """ if num_new_tokens == 0: @@ -346,22 +344,22 @@ index 6937455e7..c36a25bc5 100644 + self.coordinator, + self.block_pool, + self.kv_cache_config.kv_cache_groups) - + if new_computed_blocks is not None: new_computed_block_list = new_computed_blocks.blocks diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py -index d34f39327..a0ab878a5 100644 +index d34f39327..141d750b3 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py @@ -155,3 +155,6 @@ class SchedulerOutput: - + # KV Cache Connector metadata. kv_connector_metadata: Optional[KVConnectorMetadata] = None + + # modified slots by sparse algorithm + req_sparsed_slots: dict[str, int] = None diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py -index fe552db74..cb6f44227 100644 +index fe552db74..6a9d4b4b9 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -34,6 +34,8 @@ from vllm.v1.outputs import ModelRunnerOutput @@ -370,9 +368,9 @@ index fe552db74..cb6f44227 100644 from vllm.v1.structured_output import StructuredOutputManager +from ucm.integration.vllm.ucm_sparse.state import ensure_ucm_sparse_initialized, get_ucm_sparse, has_ucm_sparse +from ucm.integration.vllm.ucm_sparse.base import UcmSparseBase, UcmSparseRole, INVALID_SLOT - + logger = init_logger(__name__) - + @@ -79,12 +81,18 @@ class Scheduler(SchedulerInterface): # will have a corresponding KVConnector with Role=WORKER. # KV Connector pushes/pull of remote KVs for P/D and offloading. @@ -389,11 +387,11 @@ index fe552db74..cb6f44227 100644 + ensure_ucm_sparse_initialized(vllm_config, role=UcmSparseRole.SCHEDULER) + self.ucm_sparse = get_ucm_sparse() + logger.info("UCM Sparse initialized successfully: {}".format(self.ucm_sparse)) - + self.kv_event_publisher = EventPublisherFactory.create( self.kv_events_config, @@ -201,8 +209,13 @@ class Scheduler(SchedulerInterface): - + # First, schedule the RUNNING requests. req_index = 0 + req_sparsed_slots: dict[str, int] = {} @@ -403,7 +401,7 @@ index fe552db74..cb6f44227 100644 + if self.ucm_sparse: + num_slots_sparsed = self.ucm_sparse.estimate_num_slots_sparsed(request) + req_sparsed_slots.update({request.request_id: num_slots_sparsed}) - + num_new_tokens = (request.num_tokens_with_spec - request.num_computed_tokens) @@ -250,7 +263,8 @@ class Scheduler(SchedulerInterface): @@ -418,13 +416,13 @@ index fe552db74..cb6f44227 100644 # Preempt the lowest-priority request. @@ -337,6 +351,10 @@ class Scheduler(SchedulerInterface): break - + request = self.waiting.peek_request() + num_slots_sparsed = INVALID_SLOT + if self.ucm_sparse: + num_slots_sparsed = self.ucm_sparse.estimate_num_slots_sparsed(request) + req_sparsed_slots.update({request.request_id: num_slots_sparsed}) - + # KVTransfer: skip request if still waiting for remote kvs. if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: @@ -446,6 +464,7 @@ class Scheduler(SchedulerInterface): @@ -443,23 +441,23 @@ index fe552db74..cb6f44227 100644 # finished_req_ids is an existing state in the scheduler, # instead of being newly scheduled in this step. 
# It contains the request IDs that are finished in between -@@ -745,23 +765,38 @@ class Scheduler(SchedulerInterface): +@@ -745,16 +765,31 @@ class Scheduler(SchedulerInterface): num_scheduled_tokens = scheduler_output.num_scheduled_tokens pooler_outputs = model_runner_output.pooler_output num_nans_in_logits = model_runner_output.num_nans_in_logits + invalid_block_ids = model_runner_output.invalid_block_ids - + new_running: list[Request] = [] outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list) spec_decoding_stats: Optional[SpecDecodingStats] = None -+ + + recovered_req_ids = None + if invalid_block_ids: + # These blocks contain externally computed tokens that failed to + # load. Identify affected requests and adjust their computed token + # count to trigger recomputation of the invalid blocks. + recovered_req_ids = self._handle_invalid_blocks(invalid_block_ids) - ++ # NOTE(woosuk): As len(self.running) can be up to 1K or more, the below # loop can be a performance bottleneck. We should do our best to avoid # expensive operations inside the loop. @@ -475,28 +473,21 @@ index fe552db74..cb6f44227 100644 num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0) if num_tokens_scheduled == 0: # The request was not scheduled in this step. - new_running.append(request) - continue - -- req_index = model_runner_output.req_id_to_index[req_id] -+ req_index = model_runner_output.req_id_to_index[req_id] - generated_token_ids = sampled_token_ids[ - req_index] if sampled_token_ids else [] - -@@ -792,6 +827,12 @@ class Scheduler(SchedulerInterface): +@@ -792,6 +827,13 @@ class Scheduler(SchedulerInterface): new_token_ids = generated_token_ids kv_transfer_params = None - + + if model_runner_output.finished_dumping is not None: + request.succeed_dumped_blocks.extend(model_runner_output.finished_dumping.get(req_id, [])) -+ -+ if request.num_output_tokens == 0 and (num_tokens_scheduled + request.num_computed_tokens >= request.num_prompt_tokens): -+ self.connector.connector.commit(request.succeed_dumped_blocks, True) ++ is_prefill = request.num_output_tokens == 0 ++ is_last_chunk = (num_tokens_scheduled + request.num_computed_tokens >= request.num_prompt_tokens) ++ if is_prefill and is_last_chunk: ++ self.connector.connector.commit(request.succeed_dumped_blocks, True) + # Append generated tokens and check for stop. Note that if # a request is still being prefilled, we expect the model runner # to return empty token ids for the request. -@@ -842,7 +883,6 @@ class Scheduler(SchedulerInterface): +@@ -842,7 +884,6 @@ class Scheduler(SchedulerInterface): spec_token_ids[req_index]) else: request.spec_token_ids = spec_token_ids[req_index] @@ -504,15 +495,15 @@ index fe552db74..cb6f44227 100644 # Get prompt logprobs for this request. prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) if new_token_ids or pooler_output is not None \ -@@ -869,6 +909,7 @@ class Scheduler(SchedulerInterface): - +@@ -869,6 +910,7 @@ class Scheduler(SchedulerInterface): + if not stopped: new_running.append(request) + self.running = new_running - + # KV Connector: update state for finished KV Transfers. 
-@@ -927,6 +968,8 @@ class Scheduler(SchedulerInterface): +@@ -927,6 +969,8 @@ class Scheduler(SchedulerInterface): def add_request(self, request: Request) -> None: self.waiting.add_request(request) self.requests[request.request_id] = request @@ -520,17 +511,17 @@ index fe552db74..cb6f44227 100644 + self.ucm_sparse.request_begin(request.request_id, request.prompt_token_ids) if self.log_stats: request.record_event(EngineCoreEventType.QUEUED) - -@@ -976,6 +1019,8 @@ class Scheduler(SchedulerInterface): - + +@@ -976,6 +1020,8 @@ class Scheduler(SchedulerInterface): + def _free_request(self, request: Request) -> Optional[dict[str, Any]]: assert request.is_finished() + if self.ucm_sparse: + self.ucm_sparse.request_finished_in_scheduler(request.request_id) - + delay_free_blocks, kv_xfer_params = self._connector_finished(request) self.encoder_cache_manager.free(request) -@@ -1113,3 +1158,117 @@ class Scheduler(SchedulerInterface): +@@ -1113,3 +1159,117 @@ class Scheduler(SchedulerInterface): for req_id in (model_runner_output.finished_sending or ()): logger.debug("Finished sending KV transfer for request %s", req_id) self._free_blocks(self.requests[req_id]) @@ -649,19 +640,19 @@ index fe552db74..cb6f44227 100644 + # update_from_output. + return {r.request_id for r in affected_requests} diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py -index 5b4718038..d97690ae5 100644 +index 5b4718038..28bd4618a 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py -@@ -141,6 +141,9 @@ class SingleTypeKVCacheManager(ABC): - """ +@@ -142,6 +142,9 @@ class SingleTypeKVCacheManager(ABC): num_cached_blocks = self.num_cached_block[request.request_id] num_full_blocks = num_tokens // self.block_size -+ + + if num_cached_blocks >= num_full_blocks: + return - ++ self.block_pool.cache_full_blocks( request=request, + blocks=self.req_to_blocks[request.request_id], diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index b06b7cc80..61cd7110f 100644 --- a/vllm/v1/executor/multiproc_executor.py @@ -681,12 +672,12 @@ index b06b7cc80..61cd7110f 100644 + # _async_aggregate_workers_output also assumes a single IO thread self.io_thread_pool = ThreadPoolExecutor( max_workers=1, thread_name_prefix="mp_exec_io") - + self.output_rank = self._get_output_rank() + self.has_connector = self.vllm_config.kv_transfer_config is not None + self.kv_output_aggregator = KVOutputAggregator( + self.parallel_config.world_size) - + def start_worker_monitor(self): workers = self.workers @@ -155,13 +160,30 @@ class MultiprocExecutor(Executor): @@ -721,11 +712,11 @@ index b06b7cc80..61cd7110f 100644 + return self.kv_output_aggregator.async_aggregate( + outputs, self.output_rank) + return self.kv_output_aggregator.aggregate(outputs, self.output_rank) - + def collective_rpc(self, method: Union[str, Callable], diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py -index f78623f57..8697150b2 100644 +index f78623f57..c7b4100e3 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -107,6 +107,11 @@ class ModelRunnerOutput: @@ -737,7 +728,7 @@ index f78623f57..8697150b2 100644 + # IDs of externally computed KV blocks that failed to load. + # Requests referencing these blocks should be rescheduled to recompute them. 
+ invalid_block_ids: Optional[set[int]] = None - + # req_id -> num_nans_in_logits num_nans_in_logits: Optional[dict[str, int]] = None diff --git a/vllm/v1/request.py b/vllm/v1/request.py @@ -754,14 +745,13 @@ index 9b96f4599..825b77bba 100644 # indicates that the output is corrupted self.num_nans_in_logits = 0 diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py -index 8f4e8d64c..733ac1f41 100644 +index 8f4e8d64c..f45e39f5c 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py -@@ -60,6 +60,15 @@ class BlockTable: - start = self.num_blocks_per_row[row_idx] +@@ -61,6 +61,15 @@ class BlockTable: self.num_blocks_per_row[row_idx] += num_blocks self.block_table_np[row_idx, start:start + num_blocks] = block_ids -+ + + def reset_row( + self, + row_idx: int, @@ -770,26 +760,27 @@ index 8f4e8d64c..733ac1f41 100644 + self.block_table[row_idx].fill_(0) + self.block_table_cpu[row_idx].fill_(0) + self.block_table_np[row_idx].fill(0) - ++ def add_row(self, block_ids: list[int], row_idx: int) -> None: self.num_blocks_per_row[row_idx] = 0 -@@ -116,6 +125,10 @@ class MultiGroupBlockTable: - row_idx: int) -> None: + self.append_row(block_ids, row_idx) +@@ -117,6 +126,10 @@ class MultiGroupBlockTable: for i, block_table in enumerate(self.block_tables): block_table.append_row(block_ids[i], row_idx) -+ + + def reset_row(self, row_idx: int) -> None: + for i, block_table in enumerate(self.block_tables): + block_table.reset_row(row_idx) - ++ def add_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None: for i, block_table in enumerate(self.block_tables): + block_table.add_row(block_ids[i], row_idx) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 1a79d72be..0e65c98f5 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -46,6 +46,11 @@ class CachedRequestState: - + def __post_init__(self): self.num_prompt_tokens = len(self.prompt_token_ids) + # 'last_generator_offset' and 'last_gelen_last_output_token_ids' are @@ -797,7 +788,7 @@ index 1a79d72be..0e65c98f5 100644 + # invalid (e.g., due to KV load errors). + self.last_generator_offset = 0 if self.generator else None + self.len_last_output_token_ids = len(self.output_token_ids) - + @property def num_tokens(self) -> int: @@ -201,6 +206,7 @@ class InputBatch: @@ -805,7 +796,7 @@ index 1a79d72be..0e65c98f5 100644 # generator should not be included in the dictionary. 
self.generators: dict[int, torch.Generator] = {} + self.generators_last_offset: dict[int, int] = {} - + self.num_logprobs: dict[str, int] = {} # NOTE(rob): num_prompt_logprobs only includes reqs @@ -335,6 +341,9 @@ class InputBatch: @@ -815,17 +806,17 @@ index 1a79d72be..0e65c98f5 100644 + assert (request.last_generator_offset is not None) + self.generators_last_offset[ + req_index] = request.last_generator_offset - + if sampling_params.logprobs is not None: self.num_logprobs[req_id] = sampling_params.logprobs diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py -index 5a26e88db..2538bf0c2 100644 +index 5a26e88db..17b3d1c79 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -72,6 +72,9 @@ from ..sample.logits_processor import LogitsProcessorManager from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing, sanity_check_mm_encoder_outputs, scatter_mm_placeholders) - + +from ucm.integration.vllm.ucm_sparse.state import get_ucm_sparse, has_ucm_sparse +from ucm.integration.vllm.ucm_sparse.base import UcmSparseMetadata, INVALID_SLOT + @@ -851,7 +842,7 @@ index 5a26e88db..2538bf0c2 100644 new_block_ids = req_data.new_block_ids[i] resumed_from_preemption = req_data.resumed_from_preemption[i] + is_sparsed_request = req_sparsed_slots[req_id] != INVALID_SLOT - + # Update the cached states. + if (num_computed_tokens <= req_state.num_computed_tokens): + # The request was rescheduled after a KV load failure. Clear @@ -872,19 +863,18 @@ index 5a26e88db..2538bf0c2 100644 + self.input_batch.num_tokens_no_spec[req_index] = end_idx + req_state.num_computed_tokens = num_computed_tokens - + if not is_last_rank: -@@ -492,17 +516,23 @@ class GPUModelRunner(LoRAModelRunnerMixin): - elif num_new_tokens > 0: +@@ -493,16 +517,22 @@ class GPUModelRunner(LoRAModelRunnerMixin): req_state.output_token_ids.extend( new_token_ids[-num_new_tokens:]) -+ + + req_state.len_last_output_token_ids = len( + req_state.output_token_ids) + if req_state.generator: + req_state.last_generator_offset = ( + req_state.generator.get_offset()) - ++ # Update the block IDs. - if not resumed_from_preemption: - # Append the new blocks to the existing block IDs. @@ -901,36 +891,34 @@ index 5a26e88db..2538bf0c2 100644 + for block_ids, new_ids in zip(req_state.block_ids, + new_block_ids): + block_ids.extend(new_ids) - + req_index = self.input_batch.req_id_to_index.get(req_id) if req_index is None: -@@ -511,10 +541,17 @@ class GPUModelRunner(LoRAModelRunnerMixin): - # scheduled in the previous step and needs to be added again. +@@ -512,9 +542,16 @@ class GPUModelRunner(LoRAModelRunnerMixin): req_ids_to_add.append(req_id) continue -+ + + if req_state.generator: + assert (req_state.last_generator_offset is not None) + self.input_batch.generators_last_offset[ + req_index] = req_state.last_generator_offset - ++ # Update the persistent batch. 
self.input_batch.num_computed_tokens_cpu[req_index] = ( num_computed_tokens) + if is_sparsed_request: + self.input_batch.block_table.reset_row(req_index) self.input_batch.block_table.append_row(new_block_ids, req_index) - + # For the last rank, we don't need to update the token_ids_cpu -@@ -622,7 +659,20 @@ class GPUModelRunner(LoRAModelRunnerMixin): - # Only relevant for models using M-RoPE (e.g, Qwen2-VL) +@@ -623,6 +660,19 @@ class GPUModelRunner(LoRAModelRunnerMixin): if self.uses_mrope: self._calc_mrope_positions(scheduler_output) -+ + + self.seq_lens_np[:num_reqs] = ( + self.input_batch.num_computed_tokens_cpu[:num_reqs] + + num_scheduled_tokens) - ++ + # TODO: improve performance, no `positions_np.copy()` + sparsed_positions = positions_np.copy() + req_sparsed_slots = scheduler_output.req_sparsed_slots @@ -960,7 +948,7 @@ index 5a26e88db..2538bf0c2 100644 @@ -666,9 +716,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): self.query_start_loc_np[0] = 0 self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens - + - self.seq_lens_np[:num_reqs] = ( - self.input_batch.num_computed_tokens_cpu[:num_reqs] + - num_scheduled_tokens) @@ -969,7 +957,7 @@ index 5a26e88db..2538bf0c2 100644 + is_sparsed_request = scheduler_output.req_sparsed_slots[req_id] != INVALID_SLOT + if is_sparsed_request: + self.seq_lens_np[req_index] = scheduler_output.req_sparsed_slots[req_id] - + # Copy the tensors to the GPU. self.input_ids[:total_num_scheduled_tokens].copy_( @@ -680,6 +732,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): @@ -981,20 +969,18 @@ index 5a26e88db..2538bf0c2 100644 self.positions[:total_num_scheduled_tokens].copy_( self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True) -@@ -1370,7 +1424,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): +@@ -1370,6 +1424,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): skip_cuda_graphs=skip_cuda_graphs, ): self.maybe_setup_kv_connector(scheduler_output) -- + self.maybe_execute_ucm_sparse_begin(scheduler_output, attn_metadata) -+ + model_output = self.model( input_ids=input_ids, - positions=positions, @@ -1378,9 +1433,12 @@ class GPUModelRunner(LoRAModelRunnerMixin): inputs_embeds=inputs_embeds, ) - + - self.maybe_wait_for_kv_save() + finished_dumping = self.maybe_wait_for_kv_save() + self.maybe_execute_ucm_sparse_finished() @@ -1002,7 +988,7 @@ index 5a26e88db..2538bf0c2 100644 finished_sending, finished_recving = ( self.get_finished_kv_transfers(scheduler_output)) + invalid_block_ids = self.get_block_ids_with_load_errors() - + if self.use_aux_hidden_state_outputs: hidden_states, aux_hidden_states = model_output @@ -1474,7 +1532,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): @@ -1022,7 +1008,7 @@ index 5a26e88db..2538bf0c2 100644 + finished_dumping=finished_dumping, + invalid_block_ids = invalid_block_ids ) - + def propose_draft_token_ids( @@ -1693,13 +1754,16 @@ class GPUModelRunner(LoRAModelRunnerMixin): self.maybe_setup_kv_connector(scheduler_output) @@ -1030,21 +1016,21 @@ index 5a26e88db..2538bf0c2 100644 self.get_finished_kv_transfers(scheduler_output)) + invalid_block_ids = self.get_block_ids_with_load_errors() + get_kv_transfer_group().clear_connector_metadata() - + - if not finished_sending and not finished_recving: + if not finished_sending and not finished_recving and not invalid_block_ids: return EMPTY_MODEL_RUNNER_OUTPUT - + output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) output.finished_sending = finished_sending output.finished_recving = finished_recving + output.invalid_block_ids = invalid_block_ids return output - + @staticmethod @@ 
-1719,9 +1783,28 @@ class GPUModelRunner(LoRAModelRunnerMixin): kv_connector.start_load_kv(get_forward_context()) - + @staticmethod - def maybe_wait_for_kv_save() -> None: + def maybe_wait_for_kv_save() -> Optional[dict[str, list[str]]]: @@ -1070,13 +1056,13 @@ index 5a26e88db..2538bf0c2 100644 + return + ucm_sparse = get_ucm_sparse() + ucm_sparse.request_finished_in_worker(request_id) - + @staticmethod def get_finished_kv_transfers( @@ -1732,6 +1815,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): scheduler_output.finished_req_ids) return None, None - + + def get_block_ids_with_load_errors(self) -> Optional[set[int]]: + if has_kv_transfer_group(): + return get_kv_transfer_group().get_block_ids_with_load_errors() @@ -1117,9 +1103,9 @@ index 9e7e44d06..d52a49a2e 100644 from vllm.v1.worker.gpu_model_runner import GPUModelRunner from vllm.v1.worker.worker_base import WorkerBase +from ucm.integration.vllm.ucm_sparse.state import ensure_ucm_sparse_initialized - + logger = init_logger(__name__) - + @@ -313,9 +316,22 @@ class Worker(WorkerBase): assert isinstance(output, IntermediateTensors) get_pp_group().send_tensor_dict(output.tensors, @@ -1142,17 +1128,17 @@ index 9e7e44d06..d52a49a2e 100644 assert isinstance(output, ModelRunnerOutput) - return output if self.is_driver_worker else None + return output - + def profile(self, is_start: bool = True): if self.profiler is None: @@ -386,6 +402,7 @@ def init_worker_distributed_environment( parallel_config.pipeline_parallel_size) - + ensure_kv_transfer_initialized(vllm_config) + ensure_ucm_sparse_initialized(vllm_config) - - + + def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): --- -2.34.1 +-- +2.50.1.windows.1
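
Note (not part of the patch): the change to examples/offline_inference.py above only wires an already-defined `ktc` object into `EngineArgs` via the existing `kv_transfer_config` argument. A minimal sketch of how such an object might be constructed is shown below; the connector name and the extra-config key are illustrative assumptions, not values taken from this repository.

    from vllm import EngineArgs
    from vllm.config import KVTransferConfig

    # Hypothetical connector name and extra-config key for a UCM-style setup;
    # replace them with whatever connector your deployment actually registers.
    ktc = KVTransferConfig(
        kv_connector="UnifiedCacheConnectorV1",
        kv_role="kv_both",
        kv_connector_extra_config={"ucm_config_path": "/path/to/ucm.yaml"},
    )

    llm_args = EngineArgs(
        model="/path/to/model",
        kv_transfer_config=ktc,          # the keyword argument this commit adds
        max_model_len=32768,
        gpu_memory_utilization=0.8,
        max_num_batched_tokens=30000,
    )

Passing the config through `EngineArgs` (rather than mutating the engine after construction) ensures the KV connector is initialized together with the scheduler- and worker-side components that the nested vllm-adapt.patch hooks into.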