diff --git a/docker/Dockerfile b/docker/Dockerfile index 4c76cf16..35c332a5 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -1,16 +1,21 @@ # Set to other image if needed FROM vllm/vllm-openai:v0.9.2 +ARG PIP_INDEX_URL="https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple" + WORKDIR /workspace # Install unified-cache-management COPY . /vllm-workspace/unified-cache-management +RUN pip config set global.index-url ${PIP_INDEX_URL} + RUN export PLATFORM="cuda" && \ pip install -v -e /vllm-workspace/unified-cache-management # Apply patch for vLLM RUN cd $(pip show vllm | grep Location | awk '{print $2}') \ - && git apply /vllm-workspace/unified-cache-management/unifiedcache/patch/0.9.2/vllm-adapt.patch + && git apply /vllm-workspace/unified-cache-management/unifiedcache/patch/0.9.2/vllm-adapt.patch \ + && git apply /vllm-workspace/unified-cache-management/unifiedcache/patch/0.9.2/vllm-adapt-sparse.patch ENTRYPOINT ["/bin/bash"] \ No newline at end of file diff --git a/docker/Dockerfile-NPU b/docker/Dockerfile-NPU index 83422e59..a50f3596 100644 --- a/docker/Dockerfile-NPU +++ b/docker/Dockerfile-NPU @@ -1,22 +1,27 @@ # Set to other image if needed FROM quay.io/ascend/vllm-ascend:v0.9.2rc1 +ARG PIP_INDEX_URL="https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple" + WORKDIR /workspace # Install unified-cache-management COPY . /vllm-workspace/unified-cache-management +RUN pip config set global.index-url ${PIP_INDEX_URL} + RUN export PLATFORM="ascend" && \ export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/Ascend/ascend-toolkit/latest/`uname -i`-linux/devlib && \ pip install -v -e /vllm-workspace/unified-cache-management # Apply patch for vLLM RUN cd /vllm-workspace/vllm \ - && git apply /vllm-workspace/unified-cache-management/unifiedcache/patch/0.9.2/vllm-adapt.patch + && git apply /vllm-workspace/unified-cache-management/unifiedcache/patch/0.9.2/vllm-adapt.patch \ + && git apply /vllm-workspace/unified-cache-management/unifiedcache/patch/0.9.2/vllm-adapt-sparse.patch # Apply patch for vLLM-Ascend RUN cd /vllm-workspace/vllm-ascend \ - && git apply /vllm-workspace/unified-cache-management/unifiedcache/patch/0.9.2/vllm-ascend-adapt.patch - + && git apply /vllm-workspace/unified-cache-management/unifiedcache/patch/0.9.2/vllm-ascend-adapt.patch \ + && git apply /vllm-workspace/unified-cache-management/unifiedcache/patch/0.9.2/vllm-ascend-adapt-sparse.patch CMD ["/bin/bash"] \ No newline at end of file diff --git a/setup.py b/setup.py index 4a42691e..12fafdc8 100644 --- a/setup.py +++ b/setup.py @@ -117,7 +117,7 @@ def build_cmake(self, ext: CMakeExtension): setup( name="unifiedcache", - version="0.0.1", + version="0.0.2", description="Unified Cache Management", author="Unified Cache Team", packages=find_packages(), diff --git a/unifiedcache/patch/0.9.2/vllm-adapt-sparse.patch b/unifiedcache/patch/0.9.2/vllm-adapt-sparse.patch index 83fb941c..f2da8748 100644 --- a/unifiedcache/patch/0.9.2/vllm-adapt-sparse.patch +++ b/unifiedcache/patch/0.9.2/vllm-adapt-sparse.patch @@ -1,5 +1,20 @@ +From 8cef77f16ca578122d6858d07019d471bf2c00c7 Mon Sep 17 00:00:00 2001 +From: harrisonyhq +Date: Tue, 2 Sep 2025 20:11:39 +0800 +Subject: [PATCH] [Patch] vLLM patch for UCM Sparse + +--- + vllm/attention/layer.py | 42 ++++++++++++++++++ + vllm/v1/core/kv_cache_manager.py | 32 +++++++++++++- + vllm/v1/core/sched/output.py | 3 ++ + vllm/v1/core/sched/scheduler.py | 28 +++++++++++- + vllm/v1/worker/block_table.py | 13 ++++++ + vllm/v1/worker/gpu_model_runner.py | 70 +++++++++++++++++++++++++----- + vllm/v1/worker/gpu_worker.py | 2 + + 7 files changed, 177 insertions(+), 13 deletions(-) + diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py -index f0ad68b16..db0d8a58d 100644 +index f0ad68b16..847c97371 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -22,6 +22,7 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod @@ -10,17 +25,18 @@ index f0ad68b16..db0d8a58d 100644 class Attention(nn.Module): -@@ -409,8 +410,10 @@ def unified_attention( +@@ -409,9 +410,11 @@ def unified_attention( attn_metadata = attn_metadata[layer_name] self = forward_context.no_compile_layers[layer_name] kv_cache = self.kv_cache[forward_context.virtual_engine] + maybe_execute_sparse_attention_begin(query, key, value, layer_name, forward_context) output = self.impl.forward(self, query, key, value, kv_cache, attn_metadata) -+ maybe_execute_sparse_attention_finished(query, key, value, output, layer_name, forward_context) ++ maybe_execute_sparse_attention_finished(query, key, value, output, layer_name, forward_context) maybe_save_kv_layer_to_connector(layer_name, kv_cache) return output + @@ -449,6 +452,7 @@ def unified_attention_with_output( attn_metadata = attn_metadata[layer_name] self = forward_context.no_compile_layers[layer_name] @@ -29,16 +45,15 @@ index f0ad68b16..db0d8a58d 100644 self.impl.forward(self, query, key, -@@ -457,7 +461,7 @@ def unified_attention_with_output( - attn_metadata, +@@ -458,6 +462,7 @@ def unified_attention_with_output( output=output, output_scale=output_scale) -- + + maybe_execute_sparse_attention_finished(query, key, value, output, layer_name, forward_context) maybe_save_kv_layer_to_connector(layer_name, kv_cache) -@@ -479,3 +483,40 @@ direct_register_custom_op( +@@ -479,3 +484,40 @@ direct_register_custom_op( fake_impl=unified_attention_with_output_fake, dispatch_key=current_platform.dispatch_key, ) @@ -79,18 +94,19 @@ index f0ad68b16..db0d8a58d 100644 + return + + ucm_sparse.attention_finished(query, key, value, attn_output, layer_name, forward_context) -\ No newline at end of file diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py -index 6937455e7..764931668 100644 +index 6937455e7..3a44db442 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py -@@ -3,7 +3,8 @@ +@@ -1,9 +1,10 @@ + # SPDX-License-Identifier: Apache-2.0 + # SPDX-FileCopyrightText: Copyright contributors to the vLLM project ++import math from collections import defaultdict from dataclasses import dataclass -from typing import Optional +from typing import Optional, Union -+import math from vllm.distributed.kv_events import KVCacheEvent from vllm.logger import init_logger @@ -111,12 +127,12 @@ index 6937455e7..764931668 100644 ) -> Optional[KVCacheBlocks]: """Add slots for a request with new tokens to append. -@@ -231,6 +235,31 @@ class KVCacheManager: - """ +@@ -232,6 +236,32 @@ class KVCacheManager: if num_new_tokens == 0: raise ValueError("num_new_tokens must be greater than 0") -+ + + if num_slots_sparsed != INVALID_SLOT: ++ self.block_size = self.kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size + num_blocks_need = math.ceil(num_slots_sparsed / self.block_size) + allocated_blocks = self.coordinator.get_blocks(request.request_id)[0] + returned_blocks = [] @@ -140,9 +156,10 @@ index 6937455e7..764931668 100644 + return None + new_blocks = self.coordinator.allocate_new_blocks(request.request_id, num_slots_sparsed) + return KVCacheBlocks(tuple([sparsed_blocks])) - ++ if new_computed_blocks is not None: new_computed_block_list = new_computed_blocks.blocks + else: diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py index d34f39327..141d750b3 100644 --- a/vllm/v1/core/sched/output.py @@ -155,7 +172,7 @@ index d34f39327..141d750b3 100644 + # modified slots by sparse algorithm + req_sparsed_slots: dict[str, int] = None diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py -index fe552db74..d16785ce0 100644 +index 22c0ad8d6..c5c39a2b8 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -35,6 +35,9 @@ from vllm.v1.request import Request, RequestStatus @@ -239,69 +256,32 @@ index fe552db74..d16785ce0 100644 # finished_req_ids is an existing state in the scheduler, # instead of being newly scheduled in this step. # It contains the request IDs that are finished in between -@@ -869,6 +891,9 @@ class Scheduler(SchedulerInterface): - - if not stopped: - new_running.append(request) -+ -+ if model_runner_output.finished_dumping is not None: -+ request.succeed_dumped_blocks.extend(model_runner_output.finished_dumping.get(req_id, [])) - self.running = new_running - - # KV Connector: update state for finished KV Transfers. -@@ -927,6 +952,8 @@ class Scheduler(SchedulerInterface): +@@ -929,6 +951,8 @@ class Scheduler(SchedulerInterface): def add_request(self, request: Request) -> None: self.waiting.add_request(request) self.requests[request.request_id] = request + if self.ucm_sparse: -+ self.ucm_sparse.request_begin(request.request_id, request.prompt_token_ids) ++ self.ucm_sparse.request_begin(request.request_id, request.prompt_token_ids) if self.log_stats: request.record_event(EngineCoreEventType.QUEUED) -@@ -976,7 +1003,8 @@ class Scheduler(SchedulerInterface): - +@@ -979,6 +1003,8 @@ class Scheduler(SchedulerInterface): def _free_request(self, request: Request) -> Optional[dict[str, Any]]: assert request.is_finished() -- + + if self.ucm_sparse: + self.ucm_sparse.request_finished_in_scheduler(request.request_id) delay_free_blocks, kv_xfer_params = self._connector_finished(request) self.encoder_cache_manager.free(request) request_id = request.request_id -diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py -index f78623f57..c8388baed 100644 ---- a/vllm/v1/outputs.py -+++ b/vllm/v1/outputs.py -@@ -107,6 +107,7 @@ class ModelRunnerOutput: - # [req_ids] - finished_sending: Optional[set[str]] = None - finished_recving: Optional[set[str]] = None -+ finished_dumping: Optional[dict[str, list[str]]] = None - - # req_id -> num_nans_in_logits - num_nans_in_logits: Optional[dict[str, int]] = None -diff --git a/vllm/v1/request.py b/vllm/v1/request.py -index 9b96f4599..e70d1695b 100644 ---- a/vllm/v1/request.py -+++ b/vllm/v1/request.py -@@ -103,6 +103,8 @@ class Request: - # The number of tokens with prefix cache hits. - self.num_cached_tokens = -1 - -+ self.succeed_dumped_blocks: list[str] = [] -+ - # The number of NaNs in logits. A value greater than 0 - # indicates that the output is corrupted - self.num_nans_in_logits = 0 diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py -index 8f4e8d64c..d5be44680 100644 +index 8f4e8d64c..eda1ed2cb 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py -@@ -60,6 +60,15 @@ class BlockTable: - start = self.num_blocks_per_row[row_idx] +@@ -61,6 +61,15 @@ class BlockTable: self.num_blocks_per_row[row_idx] += num_blocks self.block_table_np[row_idx, start:start + num_blocks] = block_ids -+ + + def reset_row( + self, + row_idx: int, @@ -310,9 +290,10 @@ index 8f4e8d64c..d5be44680 100644 + self.block_table[row_idx].fill_(0) + self.block_table_cpu[row_idx].fill_(0) + self.block_table_np[row_idx].fill(0) - ++ def add_row(self, block_ids: list[int], row_idx: int) -> None: self.num_blocks_per_row[row_idx] = 0 + self.append_row(block_ids, row_idx) @@ -117,6 +126,10 @@ class MultiGroupBlockTable: for i, block_table in enumerate(self.block_tables): block_table.append_row(block_ids[i], row_idx) @@ -320,12 +301,12 @@ index 8f4e8d64c..d5be44680 100644 + def reset_row(self, row_idx: int) -> None: + for i, block_table in enumerate(self.block_tables): + block_table.reset_row(row_idx) -+ ++ def add_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None: for i, block_table in enumerate(self.block_tables): block_table.add_row(block_ids[i], row_idx) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py -index 5a26e88db..3cbd79c00 100644 +index 14278bb6a..84a597b0e 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -72,6 +72,9 @@ from ..sample.logits_processor import LogitsProcessorManager @@ -365,32 +346,34 @@ index 5a26e88db..3cbd79c00 100644 # Update the block IDs. - if not resumed_from_preemption: +- # Append the new blocks to the existing block IDs. +- for block_ids, new_ids in zip(req_state.block_ids, +- new_block_ids): +- block_ids.extend(new_ids) +- else: + if resumed_from_preemption or is_sparsed_request: -+ # The request is resumed from preemption. -+ # Replace the existing block IDs with the new ones. -+ req_state.block_ids = new_block_ids + # The request is resumed from preemption. + # Replace the existing block IDs with the new ones. + req_state.block_ids = new_block_ids + else: - # Append the new blocks to the existing block IDs. - for block_ids, new_ids in zip(req_state.block_ids, - new_block_ids): - block_ids.extend(new_ids) -- else: -- # The request is resumed from preemption. -- # Replace the existing block IDs with the new ones. -- req_state.block_ids = new_block_ids ++ # Append the new blocks to the existing block IDs. ++ for block_ids, new_ids in zip(req_state.block_ids, ++ new_block_ids): ++ block_ids.extend(new_ids) req_index = self.input_batch.req_id_to_index.get(req_id) if req_index is None: -@@ -515,6 +521,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): +@@ -515,6 +521,9 @@ class GPUModelRunner(LoRAModelRunnerMixin): # Update the persistent batch. self.input_batch.num_computed_tokens_cpu[req_index] = ( num_computed_tokens) + if is_sparsed_request: + self.input_batch.block_table.reset_row(req_index) ++ self.input_batch.block_table.append_row(new_block_ids, req_index) # For the last rank, we don't need to update the token_ids_cpu -@@ -623,6 +631,19 @@ class GPUModelRunner(LoRAModelRunnerMixin): +@@ -623,6 +632,19 @@ class GPUModelRunner(LoRAModelRunnerMixin): if self.uses_mrope: self._calc_mrope_positions(scheduler_output) @@ -410,13 +393,12 @@ index 5a26e88db..3cbd79c00 100644 # Get token indices. # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] -@@ -652,11 +673,12 @@ class GPUModelRunner(LoRAModelRunnerMixin): +@@ -652,11 +674,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): # block_size. block_table_indices = ( req_indices * block_table.max_num_blocks_per_req + - positions_np // block_size) + sparsed_positions // block_size) -+ block_table_cpu = block_table.get_cpu_tensor() block_numbers = block_table_cpu.flatten( )[block_table_indices].numpy() @@ -457,34 +439,18 @@ index 5a26e88db..3cbd79c00 100644 model_output = self.model( input_ids=input_ids, -@@ -1378,7 +1405,9 @@ class GPUModelRunner(LoRAModelRunnerMixin): - inputs_embeds=inputs_embeds, +@@ -1379,6 +1406,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): ) -- self.maybe_wait_for_kv_save() -+ finished_dumping = self.maybe_wait_for_kv_save() + finished_dumping = self.maybe_wait_for_kv_save() + self.maybe_execute_ucm_sparse_finished() + finished_sending, finished_recving = ( self.get_finished_kv_transfers(scheduler_output)) -@@ -1563,6 +1592,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): - finished_sending=finished_sending, - finished_recving=finished_recving, - num_nans_in_logits=num_nans_in_logits, -+ finished_dumping=finished_dumping - ) - - def propose_draft_token_ids( -@@ -1719,10 +1749,29 @@ class GPUModelRunner(LoRAModelRunnerMixin): - kv_connector.start_load_kv(get_forward_context()) - - @staticmethod -- def maybe_wait_for_kv_save() -> None: -+ def maybe_wait_for_kv_save() -> Optional[dict[str, list[str]]]: +@@ -1724,6 +1753,25 @@ class GPUModelRunner(LoRAModelRunnerMixin): if has_kv_transfer_group(): -- get_kv_transfer_group().wait_for_save() -+ return get_kv_transfer_group().wait_for_save() + return get_kv_transfer_group().wait_for_save() + def maybe_execute_ucm_sparse_begin(self, scheduler_output: "SchedulerOutput"): + if not has_ucm_sparse(): @@ -509,10 +475,10 @@ index 5a26e88db..3cbd79c00 100644 def get_finished_kv_transfers( scheduler_output: "SchedulerOutput", diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py -index 9e7e44d06..28df5ab46 100644 +index 7117f60b5..c239e1f02 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py -@@ -28,6 +28,7 @@ from vllm.v1.outputs import ModelRunnerOutput +@@ -30,6 +30,7 @@ from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput from vllm.v1.utils import report_usage_stats from vllm.v1.worker.gpu_model_runner import GPUModelRunner from vllm.v1.worker.worker_base import WorkerBase @@ -520,7 +486,7 @@ index 9e7e44d06..28df5ab46 100644 logger = init_logger(__name__) -@@ -386,6 +387,7 @@ def init_worker_distributed_environment( +@@ -400,6 +401,7 @@ def init_worker_distributed_environment( parallel_config.pipeline_parallel_size) ensure_kv_transfer_initialized(vllm_config) @@ -528,3 +494,8 @@ index 9e7e44d06..28df5ab46 100644 def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): + +base-commit: 0362d815b534a9e4a597f1fc5887d950896304b4 +-- +2.50.1.windows.1 + diff --git a/unifiedcache/patch/0.9.2/vllm-adapt.patch b/unifiedcache/patch/0.9.2/vllm-adapt.patch index c5e8ac9e..670ee158 100644 --- a/unifiedcache/patch/0.9.2/vllm-adapt.patch +++ b/unifiedcache/patch/0.9.2/vllm-adapt.patch @@ -1,7 +1,7 @@ From 64a94cbdbc38df6f046379c59ac893545ddbd407 Mon Sep 17 00:00:00 2001 From: flesher0813 <1208954694@qq.com> Date: Sat, 16 Aug 2025 16:57:04 +0800 -Subject: [PATCH] [WIP][v1] Support for returning a value when using +Subject: [PATCH 1/2] [WIP][v1] Support for returning a value when using wait_for_save Signed-off-by: flesher0813 <1208954694@qq.com> @@ -102,3 +102,265 @@ index 5a26e88db..14278bb6a 100644 -- 2.50.1.windows.1 + +From c00b8ca6f917831ad8f14a5d1449a3fd0a1480f5 Mon Sep 17 00:00:00 2001 +From: flesher0813 <1208954694@qq.com> +Date: Sat, 30 Aug 2025 19:13:35 +0800 +Subject: [PATCH 2/2] [BugFix] adapted workers output for dumped blocks + +--- + .../kv_transfer/kv_connector/utils.py | 109 ++++++++++++++++++ + vllm/v1/executor/multiproc_executor.py | 30 ++++- + vllm/v1/worker/gpu_worker.py | 22 +++- + 3 files changed, 153 insertions(+), 8 deletions(-) + +diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py +index 5cbc8ca31..06e71f107 100644 +--- a/vllm/distributed/kv_transfer/kv_connector/utils.py ++++ b/vllm/distributed/kv_transfer/kv_connector/utils.py +@@ -3,12 +3,18 @@ + """ + KV cache helper for store. + """ ++from collections import defaultdict ++from collections.abc import Sequence ++from concurrent.futures import CancelledError, Future ++from typing import Optional, cast ++ + import torch + + import vllm.envs as envs + from vllm import _custom_ops as ops + from vllm.config import VllmConfig, get_current_vllm_config + from vllm.logger import init_logger ++from vllm.v1.outputs import ModelRunnerOutput + + logger = init_logger(__name__) + +@@ -107,3 +113,106 @@ def get_kv_connector_cache_layout(): + "layout to HND for better xfer performance.") + return "HND" + return "NHD" ++ ++ ++class KVOutputAggregator: ++ """Utility class to aggregate the output of all workers into a single ++ output corresponding to Rank 0 for scheduler.""" ++ ++ def __init__(self, world_size: int): ++ # Complete transfer tracker. Used by to track finished requests ++ # [req_id -> n_finished_workers] ++ self._recv_remaining_count = defaultdict[str, int](lambda: world_size) ++ self._send_remaining_count = defaultdict[str, int](lambda: world_size) ++ self._dump_remaining_count = defaultdict[str, int](lambda: world_size) ++ ++ def aggregate(self, ++ outputs: list[ModelRunnerOutput], ++ output_rank: int = 0) -> ModelRunnerOutput: ++ # aggregate finished_sending, finished_recving from all workers ++ ++ def update_finished_set(req_ids: Optional[set[str]], ++ remaining_count_dict: dict[str, int], ++ finished_set: set[str]) -> None: ++ for req_id in req_ids or (): ++ new_count = remaining_count_dict[req_id] - 1 ++ if new_count == 0: ++ finished_set.add(req_id) ++ del remaining_count_dict[req_id] ++ else: ++ remaining_count_dict[req_id] = new_count ++ ++ def update_finished_list(req_ids: Optional[dict[str, list[str]]], ++ remaining_count_dict: dict[str, int], ++ finished_list: dict[str, list[str]]) -> None: ++ for req_id, succeed_dump_blocks in (req_ids or {}).items(): ++ if req_id not in finished_list: ++ finished_list[req_id] = [] ++ for blk_id in succeed_dump_blocks: ++ new_count = remaining_count_dict[blk_id] - 1 ++ if new_count == 0: ++ finished_list[req_id].append(blk_id) ++ del remaining_count_dict[blk_id] ++ else: ++ remaining_count_dict[blk_id] = new_count ++ ++ finished_sending = set[str]() ++ finished_recving = set[str]() ++ finished_dumping: dict[str, list[str]] = {} ++ for output in outputs: ++ update_finished_set(output.finished_sending, ++ self._send_remaining_count, finished_sending) ++ update_finished_set(output.finished_recving, ++ self._recv_remaining_count, finished_recving) ++ update_finished_list(output.finished_dumping, ++ self._dump_remaining_count, finished_dumping) ++ ++ # select output of the worker specified by output_rank ++ output = outputs[output_rank] ++ ++ # set the aggregated finished_sending / finished_recving ++ # if output.finished_sending/recving is not empty, but the other ranks ++ # still have unfinished send/recv, we want to set the aggregated ++ # finished_sending/recving to None until all ranks have finished ++ # send/recv ++ output.finished_sending = finished_sending if finished_sending else None ++ output.finished_recving = finished_recving if finished_recving else None ++ output.finished_dumping = finished_dumping if finished_dumping else None ++ ++ return output ++ ++ def async_aggregate(self, ++ output_futures: Sequence[Future[ModelRunnerOutput]], ++ output_rank: int = 0) -> Future[ModelRunnerOutput]: ++ """Takes a list of futures and returns a single future which resolves ++ to the respective list of outputs.""" ++ result_future: Future[ModelRunnerOutput] = Future() ++ ++ outputs: list[Optional[ModelRunnerOutput]] = [None ++ ] * len(output_futures) ++ ++ def make_callback(idx): ++ ++ def callback(fut): ++ if result_future.done(): ++ return ++ ++ try: ++ outputs[idx] = fut.result() ++ except CancelledError: ++ result_future.cancel() ++ except Exception as e: ++ result_future.set_exception(e) ++ ++ # this check assumes io_thread_pool uses a single thread ++ if all(outputs): ++ result_future.set_result( ++ self.aggregate(cast(list[ModelRunnerOutput], outputs), ++ output_rank)) ++ ++ return callback ++ ++ for i, output_future in enumerate(output_futures): ++ output_future.add_done_callback(make_callback(i)) ++ ++ return result_future +diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py +index b06b7cc80..22c22a148 100644 +--- a/vllm/v1/executor/multiproc_executor.py ++++ b/vllm/v1/executor/multiproc_executor.py +@@ -26,6 +26,7 @@ from vllm.distributed import (destroy_distributed_environment, + destroy_model_parallel) + from vllm.distributed.device_communicators.shm_broadcast import (Handle, + MessageQueue) ++from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator + from vllm.executor.multiproc_worker_utils import ( + _add_prefix, set_multiprocessing_worker_envs) + from vllm.logger import init_logger +@@ -111,10 +112,14 @@ class MultiprocExecutor(Executor): + if self.max_concurrent_batches > 1: + # Note: must use only 1 IO thread to keep dequeue sequence + # from the response queue ++ # _async_aggregate_workers_output also assumes a single IO thread + self.io_thread_pool = ThreadPoolExecutor( + max_workers=1, thread_name_prefix="mp_exec_io") + + self.output_rank = self._get_output_rank() ++ self.has_connector = self.vllm_config.kv_transfer_config is not None ++ self.kv_output_aggregator = KVOutputAggregator( ++ self.parallel_config.world_size) + + def start_worker_monitor(self): + workers = self.workers +@@ -155,13 +160,30 @@ class MultiprocExecutor(Executor): + self, + scheduler_output, + ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]: +- (output, ) = self.collective_rpc( ++ non_block = self.max_concurrent_batches > 1 ++ ++ if not self.has_connector: ++ # get output only from a single worker (output_rank) ++ (output, ) = self.collective_rpc( ++ "execute_model", ++ args=(scheduler_output, ), ++ unique_reply_rank=self.output_rank, ++ non_block=non_block, ++ timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS) ++ return output ++ ++ # get output from all workers ++ outputs = self.collective_rpc( + "execute_model", + args=(scheduler_output, ), +- unique_reply_rank=self.output_rank, +- non_block=self.max_concurrent_batches > 1, ++ non_block=non_block, + timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS) +- return output ++ ++ # aggregate all workers output to a single output ++ if non_block: ++ return self.kv_output_aggregator.async_aggregate( ++ outputs, self.output_rank) ++ return self.kv_output_aggregator.aggregate(outputs, self.output_rank) + + def collective_rpc(self, + method: Union[str, Callable], +diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py +index 9e7e44d06..7117f60b5 100644 +--- a/vllm/v1/worker/gpu_worker.py ++++ b/vllm/v1/worker/gpu_worker.py +@@ -1,6 +1,7 @@ + # SPDX-License-Identifier: Apache-2.0 + # SPDX-FileCopyrightText: Copyright contributors to the vLLM project + """A GPU worker class.""" ++import copy + import gc + import os + from typing import TYPE_CHECKING, Optional +@@ -15,7 +16,8 @@ from vllm.device_allocator.cumem import CuMemAllocator + from vllm.distributed import (ensure_model_parallel_initialized, + init_distributed_environment, + set_custom_all_reduce) +-from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized ++from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized, ++ has_kv_transfer_group) + from vllm.distributed.parallel_state import get_pp_group, get_tp_group + from vllm.logger import init_logger + from vllm.lora.request import LoRARequest +@@ -24,7 +26,7 @@ from vllm.platforms import current_platform + from vllm.sequence import IntermediateTensors + from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling + from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec +-from vllm.v1.outputs import ModelRunnerOutput ++from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput + from vllm.v1.utils import report_usage_stats + from vllm.v1.worker.gpu_model_runner import GPUModelRunner + from vllm.v1.worker.worker_base import WorkerBase +@@ -313,9 +315,21 @@ class Worker(WorkerBase): + assert isinstance(output, IntermediateTensors) + get_pp_group().send_tensor_dict(output.tensors, + all_gather_group=get_tp_group()) +- return None ++ if not has_kv_transfer_group(): ++ return None ++ ++ # In case of PP with kv transfer, we need to pass through the ++ # finished_sending and finished_recving buffers. ++ new_output = EMPTY_MODEL_RUNNER_OUTPUT ++ if output.finished_sending or output.finished_recving or output.finished_dumping: ++ new_output = copy.copy(new_output) ++ new_output.finished_sending = output.finished_sending ++ new_output.finished_recving = output.finished_recving ++ new_output.finished_dumping = output.finished_dumping ++ output = new_output ++ + assert isinstance(output, ModelRunnerOutput) +- return output if self.is_driver_worker else None ++ return output + + def profile(self, is_start: bool = True): + if self.profiler is None: +-- +2.50.1.windows.1 + diff --git a/unifiedcache/patch/0.9.2/vllm-ascend-adapt.patch b/unifiedcache/patch/0.9.2/vllm-ascend-adapt.patch index c0266414..6c4ca411 100644 --- a/unifiedcache/patch/0.9.2/vllm-ascend-adapt.patch +++ b/unifiedcache/patch/0.9.2/vllm-ascend-adapt.patch @@ -1,7 +1,7 @@ From d5c47a5c2620843cb1af0277ff17768f5e20e057 Mon Sep 17 00:00:00 2001 From: flesher0813 <1208954694@qq.com> Date: Mon, 28 Jul 2025 10:58:23 +0800 -Subject: [PATCH] [Feature]:Add support for the vLLM V1 connector +Subject: [PATCH 1/2] [Feature]:Add support for the vLLM V1 connector Signed-off-by: flesher0813 <1208954694@qq.com> --- @@ -173,3 +173,73 @@ index eabcdbc..f9cca93 100644 -- 2.50.1.windows.1 + +From 0501efb489472b1a08a9447d078f6b9716c8c843 Mon Sep 17 00:00:00 2001 +From: flesher0813 <1208954694@qq.com> +Date: Sat, 30 Aug 2025 19:45:52 +0800 +Subject: [PATCH 2/2] [BugFix] Modify npu worker for aggregating + modelrunner_outputs + +--- + vllm_ascend/worker/worker_v1.py | 23 +++++++++++++++++++---- + 1 file changed, 19 insertions(+), 4 deletions(-) + +diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py +index df03d50..e165506 100644 +--- a/vllm_ascend/worker/worker_v1.py ++++ b/vllm_ascend/worker/worker_v1.py +@@ -17,6 +17,7 @@ + # Adapted from vllm-project/vllm/vllm/worker/gpu_worker.py + # + ++import copy + from typing import Optional + + import torch +@@ -27,7 +28,8 @@ from vllm import envs + from vllm.config import VllmConfig + from vllm.distributed import (ensure_model_parallel_initialized, + init_distributed_environment) +-from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized ++from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized, ++ has_kv_transfer_group) + from vllm.distributed.parallel_state import get_pp_group, get_tp_group + from vllm.logger import logger + from vllm.lora.request import LoRARequest +@@ -35,7 +37,7 @@ from vllm.sequence import IntermediateTensors + from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, GiB_bytes + from vllm.v1.core.sched.output import SchedulerOutput + from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec +-from vllm.v1.outputs import ModelRunnerOutput ++from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput + from vllm.v1.worker.worker_base import WorkerBase + + import vllm_ascend.envs as envs_ascend +@@ -222,9 +224,22 @@ class NPUWorker(WorkerBase): + assert isinstance(output, IntermediateTensors) + get_pp_group().send_tensor_dict(output.tensors, + all_gather_group=get_tp_group()) +- return None ++ if not has_kv_transfer_group(): ++ return None ++ ++ kv_connector_output = output.kv_connector_output ++ finished_sending = kv_connector_output.finished_sending ++ finished_recving = kv_connector_output.finished_recving ++ ++ if not finished_sending and not finished_recving: ++ return EMPTY_MODEL_RUNNER_OUTPUT ++ ++ new_output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) ++ new_output.kv_connector_output = kv_connector_output ++ return new_output ++ + assert isinstance(output, ModelRunnerOutput) +- return output if self.is_driver_worker else None ++ return output + + def load_model(self) -> None: + if self.vllm_config.model_config.enable_sleep_mode: +-- +2.50.1.windows.1 + diff --git a/unifiedcache/ucm_sparse/esa.py b/unifiedcache/ucm_sparse/esa.py index cbeabead..344e30d4 100644 --- a/unifiedcache/ucm_sparse/esa.py +++ b/unifiedcache/ucm_sparse/esa.py @@ -378,8 +378,10 @@ def __init__(self, vllm_config: VllmConfig, role: UcmSparseRole): self.req_states: dict[str, ReqStatePerLayer] = {} self.rank = vllm_config.parallel_config.rank self.tp_size = vllm_config.parallel_config.tensor_parallel_size + self.block_size = vllm_config.cache_config.block_size config = {"max_cache_size": 5368709120, "device": self.rank, "role": "worker"} self.connector = UcmConnectorFactory.create_connector("UcmDram", config) + # TODO: consider init self.is_mla here def attention_begin( self, @@ -445,16 +447,17 @@ def build_sparse_meta( num_scheduled_tokens, ) in scheduler_output.num_scheduled_tokens.items(): req_state = requests[req_id] - sparse_meta.add_request( - req_id, - input_batch.req_id_to_index[req_id], - len(req_state.prompt_token_ids), - len(req_state.output_token_ids), - num_scheduled_tokens, - req_state.num_computed_tokens, - scheduler_output.req_sparsed_slots[req_id], - req_state.block_ids[0], - ) + if len(req_state.prompt_token_ids) > self.block_size: + sparse_meta.add_request( + req_id, + input_batch.req_id_to_index[req_id], + len(req_state.prompt_token_ids), + len(req_state.output_token_ids), + num_scheduled_tokens, + req_state.num_computed_tokens, + scheduler_output.req_sparsed_slots[req_id], + req_state.block_ids[0], + ) self._sparse_metadata = sparse_meta def request_begin(self, request_id: ReqType, prompt_token_ids: List[int]): @@ -470,17 +473,19 @@ def update_state_after_alloc(self, request: Request, num_blocks: int): pass def estimate_num_slots_sparsed(self, request: Request) -> int: - if request.num_output_tokens == 0: + if ( + request.num_output_tokens == 0 + or request.num_prompt_tokens < self.block_size + ): return INVALID_SLOT - block_size = self._vllm_config.cache_config.block_size - num_blocks = math.ceil(request.num_tokens / block_size) + num_blocks = math.ceil(request.num_tokens / self.block_size) mid_window_sz = int( (num_blocks - INIT_WINDOW_SZ - LOCAL_WINDOW_SZ) * SPARSE_RATIO ) - flaw = request.num_tokens % block_size + flaw = request.num_tokens % self.block_size if flaw: - flaw = block_size - flaw + flaw = self.block_size - flaw num_tokens_sparsed = ( INIT_WINDOW_SZ + mid_window_sz + LOCAL_WINDOW_SZ - ) * block_size - flaw + ) * self.block_size - flaw return num_tokens_sparsed