diff --git a/examples/offline_inference.py b/examples/offline_inference.py index 04ab80e0..426b7b23 100644 --- a/examples/offline_inference.py +++ b/examples/offline_inference.py @@ -1,8 +1,11 @@ import contextlib +import json import os import time from dataclasses import asdict +from transformers import AutoTokenizer + # Third Party from vllm import LLM, SamplingParams from vllm.config import KVTransferConfig @@ -10,6 +13,8 @@ from ucm.logger import init_logger +MODEL_PATH = "/home/models/Qwen2.5-14B-Instruct" +tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH) logger = init_logger(__name__) @@ -25,21 +30,30 @@ def build_llm_with_uc(module_path: str, name: str, model: str): kv_connector_module_path=module_path, kv_role="kv_both", kv_connector_extra_config={ - "ucm_connector_name": "UcmDramStore", + "ucm_connector_name": "UcmNfsStore", "ucm_connector_config": { - "max_cache_size": 53687091200, - "kv_block_size": 262144, + "storage_backends": "/home/data", + "kv_block_size": 33554432, + }, + "ucm_sparse_config": { + "ESA": { + "init_window_sz": 1, + "local_window_sz": 2, + "min_blocks": 4, + "sparse_ratio": 0.3, + "retrieval_stride": 5, + } }, - "ucm_sparse_method": "GSA", }, ) llm_args = EngineArgs( model=model, - kv_transfer_config=ktc, - max_model_len=40960, - gpu_memory_utilization=0.87, + max_model_len=32768, + gpu_memory_utilization=0.8, + max_num_batched_tokens=30000, block_size=128, + enforce_eager=True, ) llm = LLM(**asdict(llm_args)) @@ -72,17 +86,35 @@ def main(): setup_environment_variables() - with build_llm_with_uc(module_path, name, model) as llm: - prompts = [ - "Imagine you are an artificial intelligence developed in the year 2075, designed to assist humanity in " - "navigating the complex ethical, philosophical, and technological challenges of a rapidly evolving world. " - "You have access to vast historical records, scientific data, and human literature, and your core " - "directive is to promote sustainable development, social equity, and the flourishing of conscious beings. " - "Write a detailed letter to the leaders of Earth, explaining the most urgent global issue of the 21st " - "century, the root sauses behind it, and a set of scientifically grounded, morally sound, and globally " - "cooperative solutions that transcend culturak and national boundaries. Include both immediate actions " - "and long-term strategies." 
* 200 + def get_prompt(prompt): + messages = [ + { + "role": "system", + "content": "先读问题,再根据下面的文章内容回答问题,不要进行分析,不要重复问题,用简短的语句给出答案。\n\n例如:“全国美国文学研究会的第十八届年会在哪所大学举办的?”\n回答应该为:“xx大学”。\n\n", + }, + {"role": "user", "content": prompt}, ] + return tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + add_special_tokens=True, + ) + + with build_llm_with_uc(module_path, name, model) as llm: + prompts = [] + + batch_size = 1 + + with open("/home/datasets/Longbench/data/multifieldqa_zh.jsonl", "r") as f: + for _ in range(batch_size): + line = f.readline() + if not line: + break + data = json.loads(line) + context = data["context"] + question = data["input"] + prompts.append(get_prompt(f"{context}\n\n{question}")) sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=100) diff --git a/setup.py b/setup.py index a8edbdc5..1829efdf 100644 --- a/setup.py +++ b/setup.py @@ -35,9 +35,11 @@ STORE_SRC_DIR = os.path.join(ROOT_DIR, "ucm", "store") GSA_SRC_DIR = os.path.join(ROOT_DIR, "ucm", "csrc", "gsaoffloadops") PREFETCH_SRC_DIR = os.path.join(ROOT_DIR, "ucm", "csrc", "ucmprefetch") +RETRIEVAL_SRC_DIR = os.path.join(ROOT_DIR, "ucm", "csrc", "esaretrieval") STORE_INSTALL_DIR = os.path.join(ROOT_DIR, "ucm", "store", "connector") GSA_INSTALL_DIR = os.path.join(ROOT_DIR, "ucm", "ucm_sparse") +RETRIEVAL_INSTALL_DIR = os.path.join(ROOT_DIR, "ucm", "ucm_sparse", "retrieval") PLATFORM = os.getenv("PLATFORM") @@ -89,7 +91,7 @@ def build_cmake(self, ext: CMakeExtension): subprocess.check_call(cmake_args, cwd=build_dir) - if ext.name in ["store", "gsa_offload_ops"]: + if ext.name in ["store", "gsa_offload_ops", "esaretrieval"]: subprocess.check_call(["make", "-j", "8"], cwd=build_dir) else: # 对于gsa_prefetch使用cmake --build @@ -115,6 +117,8 @@ def _copy_so_files(self, ext: CMakeExtension): search_patterns.extend(["gsa_offload_ops"]) elif ext.name == "gsa_prefetch": search_patterns.extend(["prefetch"]) + elif ext.name == "esaretrieval": + search_patterns.extend(["retrieval_backend"]) for file in os.listdir(so_search_dir): if file.endswith(".so") or ".so." 
in file:
@@ -124,8 +128,11 @@ def _copy_so_files(self, ext: CMakeExtension):
                        break
 
         if ext.name == "store":
-            install_dir = STORE_INSTALL_DIR
-            build_install_dir = STORE_INSTALL_DIR
+            install_dir = FSSTORE_INSTALL_DIR
+            build_install_dir = "ucm/store"
+        elif ext.name == "esaretrieval":
+            install_dir = RETRIEVAL_INSTALL_DIR
+            build_install_dir = "ucm/ucm_sparse/retrieval"
         else:
             install_dir = GSA_INSTALL_DIR
             build_install_dir = "ucm/ucm_sparse"
@@ -134,7 +141,6 @@ def _copy_so_files(self, ext: CMakeExtension):
             src_path = os.path.join(so_search_dir, so_file)
             dev_path = os.path.join(install_dir, so_file)
             dst_path = os.path.join(self.build_lib, build_install_dir, so_file)
-            os.makedirs(os.path.dirname(dst_path), exist_ok=True)
             shutil.copy(src_path, dst_path)
             print(f"[INFO] Copied {so_file} → {dst_path}")
 
@@ -149,6 +155,7 @@ def _copy_so_files(self, ext: CMakeExtension):
 ext_modules.append(CMakeExtension(name="store", sourcedir=STORE_SRC_DIR))
 ext_modules.append(CMakeExtension(name="gsa_offload_ops", sourcedir=GSA_SRC_DIR))
 ext_modules.append(CMakeExtension(name="gsa_prefetch", sourcedir=PREFETCH_SRC_DIR))
+ext_modules.append(CMakeExtension(name="esaretrieval", sourcedir=RETRIEVAL_SRC_DIR))
 
 setup(
     name="ucm",
diff --git a/ucm/csrc/esaretrieval/CMakeLists.txt b/ucm/csrc/esaretrieval/CMakeLists.txt
new file mode 100644
index 00000000..546d59c1
--- /dev/null
+++ b/ucm/csrc/esaretrieval/CMakeLists.txt
@@ -0,0 +1,32 @@
+cmake_minimum_required(VERSION 3.14)
+project(retrieval_backend LANGUAGES CXX)
+
+set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_CXX_STANDARD_REQUIRED ON)
+
+find_package(Python3 COMPONENTS Interpreter Development REQUIRED)
+
+include(FetchContent)
+FetchContent_Declare(
+    pybind11
+    GIT_REPOSITORY https://github.com/pybind/pybind11.git
+    GIT_TAG        v2.13.6
+    GIT_SHALLOW    TRUE
+)
+FetchContent_MakeAvailable(pybind11)
+
+pybind11_add_module(retrieval_backend
+    retrieval_backend.cpp
+)
+
+set(OUTPUT_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/output)
+set_target_properties(retrieval_backend PROPERTIES
+    PREFIX ""
+    SUFFIX ".so"
+    LIBRARY_OUTPUT_DIRECTORY "${OUTPUT_ROOT}/lib"
+    RUNTIME_OUTPUT_DIRECTORY "${OUTPUT_ROOT}/bin"
+    ARCHIVE_OUTPUT_DIRECTORY "${OUTPUT_ROOT}/lib"
+)
+
+target_compile_options(retrieval_backend PRIVATE -O3 -Wall -fPIC)
+target_link_libraries(retrieval_backend PRIVATE Python3::Python)
diff --git a/ucm/csrc/esaretrieval/retrieval_backend.cpp b/ucm/csrc/esaretrieval/retrieval_backend.cpp
new file mode 100644
index 00000000..7b1c212b
--- /dev/null
+++ b/ucm/csrc/esaretrieval/retrieval_backend.cpp
@@ -0,0 +1,225 @@
+// retrieval_backend.cpp
+
+#include <pybind11/pybind11.h>
+#include <pybind11/numpy.h>
+#include <vector>
+#include <queue>
+#include <thread>
+#include <mutex>
+#include <condition_variable>
+#include <unordered_map>
+#include <atomic>
+#include <memory>
+#include <algorithm>
+#include <cstring>
+
+namespace py = pybind11;
+
+class RetrievalWorkerBackend {
+public:
+    RetrievalWorkerBackend(py::array_t<float> data)
+        : data_array_(data), stop_workers_(false), next_req_id_(0)
+    {
+        py::buffer_info info = data_array_.request();
+        n_items_ = info.shape[0];
+        dim_ = info.shape[1];
+        data_ = static_cast<const float*>(info.ptr);
+
+        // Start worker threads
+        int n_workers = std::thread::hardware_concurrency();
+        for (int i = 0; i < n_workers; ++i) {
+            worker_threads_.emplace_back(&RetrievalWorkerBackend::worker_loop, this);
+        }
+    }
+
+    ~RetrievalWorkerBackend() {
+        {
+            std::lock_guard<std::mutex> lock(mutex_);
+            stop_workers_ = true;
+            cond_.notify_all();
+        }
+        for (auto& t: worker_threads_) t.join();
+    }
+
+    int submit(py::array_t<float> query, int topk, py::array_t<int> indexes) {
+        py::buffer_info qinfo = query.request();
+        py::buffer_info iinfo = indexes.request();
+        if (qinfo.shape[1] != dim_)
+            throw std::runtime_error("Query dim mismatch");
+        if ((size_t)iinfo.shape[0] != (size_t)qinfo.shape[0])
+            throw std::runtime_error("Query and indexes batch mismatch");
+
+        int req_id = next_req_id_.fetch_add(1);
+
+        auto q = std::vector<float>((float*)qinfo.ptr, (float*)qinfo.ptr + qinfo.shape[0] * dim_);
+
+        // Parse indexes to vector<vector<int>>
+        size_t n_requests = iinfo.shape[0], max_index_number = iinfo.shape[1];
+        const int* idx_ptr = static_cast<const int*>(iinfo.ptr);
+        std::vector<std::vector<int>> idxvec(n_requests);
+        for (size_t i = 0; i < n_requests; ++i) {
+            for (size_t j = 0; j < max_index_number; ++j) {
+                int index = idx_ptr[i * max_index_number + j];
+                if (index != -1) idxvec[i].push_back(index);
+            }
+        }
+
+        auto status = std::make_shared<RequestStatus>();
+        {
+            std::lock_guard<std::mutex> lock(mutex_);
+            requests_.emplace(Request{req_id, std::move(q), n_requests, topk, std::move(idxvec)});
+            request_status_[req_id] = status;
+        }
+        cond_.notify_one();
+        return req_id;
+    }
+
+    bool poll(int req_id) {
+        std::lock_guard<std::mutex> lock(mutex_);
+        return results_.find(req_id) != results_.end();
+    }
+
+    void wait(int req_id) {
+        std::shared_ptr<RequestStatus> s;
+        {
+            std::lock_guard<std::mutex> lock(mutex_);
+            auto it = request_status_.find(req_id);
+            if (it == request_status_.end()) throw std::runtime_error("Bad req_id");
+            s = it->second;
+        }
+        std::unique_lock<std::mutex> lk2(s->m);
+        s->cv.wait(lk2, [&] { return s->done; });
+    }
+
+    py::dict get_result(int req_id) {
+        std::lock_guard<std::mutex> lock(mutex_);
+        auto it = results_.find(req_id);
+        if (it == results_.end()) throw std::runtime_error("Result not ready");
+
+        size_t batch_size = it->second.indices.size();
+        size_t topk = batch_size > 0 ? it->second.indices[0].size() : 0;
+        py::array_t<int> indices({batch_size, topk});
+        py::array_t<float> scores({batch_size, topk});
+
+        auto indices_ptr = static_cast<int*>(indices.request().ptr);
+        auto scores_ptr = static_cast<float*>(scores.request().ptr);
+
+        for (size_t i = 0; i < batch_size; ++i) {
+            memcpy(indices_ptr + i * topk, it->second.indices[i].data(), topk * sizeof(int));
+            memcpy(scores_ptr + i * topk, it->second.scores[i].data(), topk * sizeof(float));
+        }
+        py::dict result;
+        result["indices"] = indices;
+        result["scores"] = scores;
+        results_.erase(it);
+        return result;
+    }
+
+private:
+    struct Request {
+        int req_id;
+        std::vector<float> query;               // Flattened [batch, dim]
+        size_t batch;
+        int topk;
+        std::vector<std::vector<int>> indexes;  // Per-request index subset
+    };
+    struct Result {
+        std::vector<std::vector<int>> indices;
+        std::vector<std::vector<float>> scores;
+    };
+    struct RequestStatus {
+        std::mutex m;
+        std::condition_variable cv;
+        bool done = false;
+    };
+
+    void worker_loop() {
+        while (true) {
+            Request req;
+            {
+                std::unique_lock<std::mutex> lock(mutex_);
+                cond_.wait(lock, [&]{ return stop_workers_ || !requests_.empty(); });
+                if (stop_workers_ && requests_.empty()) return;
+                req = std::move(requests_.front());
+                requests_.pop();
+            }
+
+            Result res;
+            res.indices.resize(req.batch);
+            res.scores.resize(req.batch);
+
+            // for performance
+            // std::mt19937 gen(42);
+            // for (size_t b = 0; b < req.batch; ++b) {
+            //     const float* q_ptr = req.query.data() + b * dim_;
+            //     const auto& allowed = req.indexes[b];
+
+            //     std::vector<int> index;
+            //     int i = 0;
+            //     for (auto &c: allowed) {
+            //         index.push_back(i++);
+            //     }
+            //     std::shuffle(index.begin(), index.end(), gen);
+            //     int curr_topk = std::min(static_cast<int>(allowed.size()), req.topk);
+            //     for (int k = 0; k < curr_topk; ++k) {
+            //         res.indices[b].push_back(allowed[index[k]]);
+            //         res.scores[b].push_back(0.0f); // Dummy/fixed score
+            //     }
+            // }
+
+            // for precision
+            for (size_t b = 0; b < req.batch; ++b) {
+                const float* q_ptr = req.query.data() + b * dim_;
+                const auto& allowed = req.indexes[b];
+                std::vector<std::pair<float, int>> heap;
+                heap.reserve(allowed.size());
+                for (auto idx : allowed) {
+                    float score = 0.0f;
+                    for (size_t d = 0; d < dim_; ++d) {
+                        score += q_ptr[d] * data_[idx * dim_ + d];
+                    }
+                    heap.emplace_back(score, idx);
+                }
+                int curr_topk = std::min((int)heap.size(), req.topk);
+                std::partial_sort(heap.begin(), heap.begin() + curr_topk, heap.end(),
+                    [](const auto& a, const auto& b){ return a.first > b.first; });
+
+                for (int k = 0; k < curr_topk; ++k) {
+                    res.scores[b].push_back(heap[k].first);
+                    res.indices[b].push_back(heap[k].second);
+                }
+            }
+
+            {
+                std::lock_guard<std::mutex> lock(mutex_);
+                results_[req.req_id] = std::move(res);
+                auto s = request_status_[req.req_id];
+                {
+                    std::lock_guard<std::mutex> lk2(s->m);
+                    s->done = true;
+                }
+                s->cv.notify_all();
+            }
+        }
+    }
+
+    py::array_t<float> data_array_;
+    const float* data_ = nullptr;
+    size_t n_items_, dim_;
+    std::queue<Request> requests_;
+    std::unordered_map<int, Result> results_;
+    std::vector<std::thread> worker_threads_;
+    std::mutex mutex_;
+    std::condition_variable cond_;
+    std::atomic<int> next_req_id_;
+    std::unordered_map<int, std::shared_ptr<RequestStatus>> request_status_;
+    bool stop_workers_;
+};
+
+PYBIND11_MODULE(retrieval_backend, m) {
+    py::class_<RetrievalWorkerBackend>(m, "RetrievalWorkerBackend")
+        .def(py::init<py::array_t<float>>())
+        .def("submit", &RetrievalWorkerBackend::submit)
+        .def("poll", &RetrievalWorkerBackend::poll)
+        .def("get_result", &RetrievalWorkerBackend::get_result)
+        .def("wait", &RetrievalWorkerBackend::wait);
+}
\ No newline at end of file
diff --git a/ucm/integration/vllm/patch/0.9.2/vllm-adapt-sparse.patch b/ucm/integration/vllm/patch/0.9.2/vllm-adapt-sparse.patch
deleted file mode 100644
index 69b06365..00000000
--- a/ucm/integration/vllm/patch/0.9.2/vllm-adapt-sparse.patch
+++ /dev/null
@@ -1,499 +0,0 @@
-From 732b37e6baf36ec47c992083bb9a810668e900fd Mon Sep 17 00:00:00 2001
-From: flesher0813 <1208954694@qq.com>
-Date: Sat, 6 Sep 2025 23:01:59 +0800
-Subject: [PATCH] [Patch] Support ucm sparse
-
-Signed-off-by: flesher0813 <1208954694@qq.com>
----
- vllm/attention/layer.py            | 44 ++++++++++++++++++-
- vllm/v1/core/kv_cache_manager.py   | 32 +++++++++++++-
- vllm/v1/core/sched/output.py       |  3 ++
- vllm/v1/core/sched/scheduler.py    | 26 ++++++++++-
- vllm/v1/worker/block_table.py      | 13 ++++++
- vllm/v1/worker/gpu_model_runner.py | 70 +++++++++++++++++++++++++-----
- vllm/v1/worker/gpu_worker.py       |  2 +
- 7 files changed, 175 insertions(+), 15 deletions(-)
-
-diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
-index f0ad68b16..cec242372 100644
---- a/vllm/attention/layer.py
-+++ b/vllm/attention/layer.py
-@@ -22,6 +22,7 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
- from vllm.platforms import _Backend, current_platform
- from vllm.utils import direct_register_custom_op
- from vllm.v1.attention.backends.utils import validate_kv_sharing_target
-+from ucm.integration.vllm.ucm_sparse.state import get_ucm_sparse, has_ucm_sparse
- 
- 
- class Attention(nn.Module):
-@@ -409,9 +410,10 @@ def unified_attention(
-     attn_metadata = attn_metadata[layer_name]
-     self = forward_context.no_compile_layers[layer_name]
-     kv_cache = self.kv_cache[forward_context.virtual_engine]
-+    maybe_execute_sparse_attention_begin(query, key, value, layer_name, forward_context)
-     output = self.impl.forward(self, query, key, value, kv_cache,
-                                attn_metadata)
--
-+    maybe_execute_sparse_attention_finished(query, key, value, output, layer_name, forward_context)
-     maybe_save_kv_layer_to_connector(layer_name, kv_cache)
-     return output
- 
-@@ -449,6
+451,7 @@ def unified_attention_with_output( - attn_metadata = attn_metadata[layer_name] - self = forward_context.no_compile_layers[layer_name] - kv_cache = self.kv_cache[forward_context.virtual_engine] -+ maybe_execute_sparse_attention_begin(query, key, value, layer_name, forward_context) - self.impl.forward(self, - query, - key, -@@ -457,7 +460,7 @@ def unified_attention_with_output( - attn_metadata, - output=output, - output_scale=output_scale) -- -+ maybe_execute_sparse_attention_finished(query, key, value, output, layer_name, forward_context) - maybe_save_kv_layer_to_connector(layer_name, kv_cache) - - -@@ -479,3 +482,40 @@ direct_register_custom_op( - fake_impl=unified_attention_with_output_fake, - dispatch_key=current_platform.dispatch_key, - ) -+ -+def maybe_execute_sparse_attention_begin( -+ query: torch.Tensor, -+ key: torch.Tensor, -+ value: torch.Tensor, -+ layer_name: str, -+ forward_context: ForwardContext, -+): -+ if not has_ucm_sparse(): -+ return -+ -+ ucm_sparse = get_ucm_sparse() -+ -+ attn_metadata = forward_context.attn_metadata -+ if attn_metadata is None: -+ return -+ -+ ucm_sparse.attention_begin(query, key, value, layer_name, forward_context) -+ -+def maybe_execute_sparse_attention_finished( -+ query: torch.Tensor, -+ key: torch.Tensor, -+ value: torch.Tensor, -+ attn_output: torch.Tensor, -+ layer_name: str, -+ forward_context: ForwardContext, -+): -+ if not has_ucm_sparse(): -+ return -+ -+ ucm_sparse = get_ucm_sparse() -+ -+ attn_metadata = forward_context.attn_metadata -+ if attn_metadata is None: -+ return -+ -+ ucm_sparse.attention_finished(query, key, value, attn_output, layer_name, forward_context) -diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py -index 6937455e7..d9bf943ab 100644 ---- a/vllm/v1/core/kv_cache_manager.py -+++ b/vllm/v1/core/kv_cache_manager.py -@@ -1,9 +1,10 @@ - # SPDX-License-Identifier: Apache-2.0 - # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -+import math - from collections import defaultdict - from dataclasses import dataclass --from typing import Optional -+from typing import Optional, Union - - from vllm.distributed.kv_events import KVCacheEvent - from vllm.logger import init_logger -@@ -14,6 +15,8 @@ from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock, - from vllm.v1.kv_cache_interface import KVCacheConfig - from vllm.v1.metrics.stats import PrefixCacheStats - from vllm.v1.request import Request, RequestStatus -+from ucm.integration.vllm.ucm_sparse.state import get_ucm_sparse, has_ucm_sparse -+from ucm.integration.vllm.ucm_sparse.base import INVALID_SLOT - - logger = init_logger(__name__) - -@@ -193,6 +196,7 @@ class KVCacheManager: - num_draft_tokens: int = 0, - num_lookahead_tokens: int = 0, - delay_cache_blocks: bool = False, -+ num_slots_sparsed: Union[None, int] = None - ) -> Optional[KVCacheBlocks]: - """Add slots for a request with new tokens to append. 
- -@@ -231,6 +235,32 @@ class KVCacheManager: - """ - if num_new_tokens == 0: - raise ValueError("num_new_tokens must be greater than 0") -+ -+ if num_slots_sparsed != INVALID_SLOT: -+ self.block_size = self.kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size -+ num_blocks_need = math.ceil(num_slots_sparsed / self.block_size) -+ allocated_blocks = self.coordinator.get_blocks(request.request_id)[0] -+ returned_blocks = [] -+ sparsed_blocks = [] -+ for i, block in enumerate(allocated_blocks): -+ if i < num_blocks_need: -+ sparsed_blocks.append(block) -+ else: -+ returned_blocks.append(block) -+ self.block_pool._maybe_evict_cached_block(block) -+ self.block_pool.free_blocks(returned_blocks) -+ self.coordinator.single_type_managers[0].req_to_blocks[request.request_id] = sparsed_blocks -+ new_computed_block_list = tuple( -+ [] for _ in range(len(self.kv_cache_config.kv_cache_groups))) -+ num_blocks_to_allocate = self.coordinator.get_num_blocks_to_allocate( -+ request_id=request.request_id, -+ num_tokens=num_slots_sparsed, -+ new_computed_blocks=new_computed_block_list, -+ ) -+ if num_blocks_to_allocate > self.block_pool.get_num_free_blocks(): -+ return None -+ new_blocks = self.coordinator.allocate_new_blocks(request.request_id, num_slots_sparsed) -+ return KVCacheBlocks(tuple([sparsed_blocks])) - - if new_computed_blocks is not None: - new_computed_block_list = new_computed_blocks.blocks -diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py -index d34f39327..a0ab878a5 100644 ---- a/vllm/v1/core/sched/output.py -+++ b/vllm/v1/core/sched/output.py -@@ -155,3 +155,6 @@ class SchedulerOutput: - - # KV Cache Connector metadata. - kv_connector_metadata: Optional[KVConnectorMetadata] = None -+ -+ # modified slots by sparse algorithm -+ req_sparsed_slots: dict[str, int] = None -diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py -index b16e38423..211699147 100644 ---- a/vllm/v1/core/sched/scheduler.py -+++ b/vllm/v1/core/sched/scheduler.py -@@ -34,6 +34,8 @@ from vllm.v1.outputs import ModelRunnerOutput - from vllm.v1.request import Request, RequestStatus - from vllm.v1.spec_decode.metrics import SpecDecodingStats - from vllm.v1.structured_output import StructuredOutputManager -+from ucm.integration.vllm.ucm_sparse.state import ensure_ucm_sparse_initialized, get_ucm_sparse, has_ucm_sparse -+from ucm.integration.vllm.ucm_sparse.base import UcmSparseBase, UcmSparseRole, INVALID_SLOT - - logger = init_logger(__name__) - -@@ -79,12 +81,18 @@ class Scheduler(SchedulerInterface): - # will have a corresponding KVConnector with Role=WORKER. - # KV Connector pushes/pull of remote KVs for P/D and offloading. 
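For reference, the retrieval_backend extension introduced under ucm/csrc/esaretrieval above is driven from Python through an asynchronous submit/poll/wait/get_result cycle. The sketch below is illustrative only and is not part of this change: the import path is assumed from the install layout in setup.py, the data is toy data, and -1 entries mark padded index slots as in the C++ source.

import numpy as np
from ucm.ucm_sparse.retrieval import retrieval_backend  # install path assumed from setup.py above

# Toy corpus: 1000 candidate vectors of dimension 128 (float32, matching py::array_t<float>).
data = np.random.rand(1000, 128).astype(np.float32)
backend = retrieval_backend.RetrievalWorkerBackend(data)

# One query restricted to candidates 0..7; unused index slots are padded with -1.
query = np.random.rand(1, 128).astype(np.float32)
indexes = np.full((1, 16), -1, dtype=np.int32)
indexes[0, :8] = np.arange(8, dtype=np.int32)

req_id = backend.submit(query, 4, indexes)  # top-4 by inner product
backend.wait(req_id)                        # or loop on backend.poll(req_id)
result = backend.get_result(req_id)         # dict with "indices" and "scores" arrays
print(result["indices"], result["scores"])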
- self.connector = None -+ self.ucm_sparse = None - if self.vllm_config.kv_transfer_config is not None: - assert len(self.kv_cache_config.kv_cache_groups) == 1, ( - "Multiple KV cache groups are not currently supported " - "with KV connectors") - self.connector = KVConnectorFactory.create_connector_v1( - config=self.vllm_config, role=KVConnectorRole.SCHEDULER) -+ # Initialize UCM Sparse if available -+ if "ucm_sparse_method" in vllm_config.kv_transfer_config.kv_connector_extra_config: -+ ensure_ucm_sparse_initialized(vllm_config, role=UcmSparseRole.SCHEDULER) -+ self.ucm_sparse = get_ucm_sparse() -+ logger.info("UCM Sparse initialized successfully: {}".format(self.ucm_sparse)) - - self.kv_event_publisher = EventPublisherFactory.create( - self.kv_events_config, -@@ -201,8 +209,13 @@ class Scheduler(SchedulerInterface): - - # First, schedule the RUNNING requests. - req_index = 0 -+ req_sparsed_slots: dict[str, int] = {} - while req_index < len(self.running) and token_budget > 0: - request = self.running[req_index] -+ num_slots_sparsed = INVALID_SLOT -+ if self.ucm_sparse: -+ num_slots_sparsed = self.ucm_sparse.estimate_num_slots_sparsed(request) -+ req_sparsed_slots.update({request.request_id: num_slots_sparsed}) - - num_new_tokens = (request.num_tokens_with_spec - - request.num_computed_tokens) -@@ -250,7 +263,8 @@ class Scheduler(SchedulerInterface): - request, - num_new_tokens, - num_draft_tokens=num_draft_tokens, -- num_lookahead_tokens=self.num_lookahead_tokens) -+ num_lookahead_tokens=self.num_lookahead_tokens, -+ num_slots_sparsed=num_slots_sparsed) - if new_blocks is None: - # The request cannot be scheduled. - # Preempt the lowest-priority request. -@@ -337,6 +351,10 @@ class Scheduler(SchedulerInterface): - break - - request = self.waiting.peek_request() -+ num_slots_sparsed = INVALID_SLOT -+ if self.ucm_sparse: -+ num_slots_sparsed = self.ucm_sparse.estimate_num_slots_sparsed(request) -+ req_sparsed_slots.update({request.request_id: num_slots_sparsed}) - - # KVTransfer: skip request if still waiting for remote kvs. - if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: -@@ -446,6 +464,7 @@ class Scheduler(SchedulerInterface): - new_computed_blocks, - num_lookahead_tokens=self.num_lookahead_tokens, - delay_cache_blocks=load_kv_async, -+ num_slots_sparsed=num_slots_sparsed - ) - if new_blocks is None: - # The request cannot be scheduled. -@@ -559,6 +578,7 @@ class Scheduler(SchedulerInterface): - scheduled_spec_decode_tokens=scheduled_spec_decode_tokens, - scheduled_encoder_inputs=scheduled_encoder_inputs, - num_common_prefix_blocks=num_common_prefix_blocks, -+ req_sparsed_slots=req_sparsed_slots, - # finished_req_ids is an existing state in the scheduler, - # instead of being newly scheduled in this step. 
- # It contains the request IDs that are finished in between -@@ -941,6 +961,8 @@ class Scheduler(SchedulerInterface): - def add_request(self, request: Request) -> None: - self.waiting.add_request(request) - self.requests[request.request_id] = request -+ if self.ucm_sparse: -+ self.ucm_sparse.request_begin(request.request_id, request.prompt_token_ids) - if self.log_stats: - request.record_event(EngineCoreEventType.QUEUED) - -@@ -990,6 +1012,8 @@ class Scheduler(SchedulerInterface): - - def _free_request(self, request: Request) -> Optional[dict[str, Any]]: - assert request.is_finished() -+ if self.ucm_sparse: -+ self.ucm_sparse.request_finished_in_scheduler(request.request_id) - - delay_free_blocks, kv_xfer_params = self._connector_finished(request) - self.encoder_cache_manager.free(request) -diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py -index 8f4e8d64c..733ac1f41 100644 ---- a/vllm/v1/worker/block_table.py -+++ b/vllm/v1/worker/block_table.py -@@ -60,6 +60,15 @@ class BlockTable: - start = self.num_blocks_per_row[row_idx] - self.num_blocks_per_row[row_idx] += num_blocks - self.block_table_np[row_idx, start:start + num_blocks] = block_ids -+ -+ def reset_row( -+ self, -+ row_idx: int, -+ ) -> None: -+ self.num_blocks_per_row[row_idx] = 0 -+ self.block_table[row_idx].fill_(0) -+ self.block_table_cpu[row_idx].fill_(0) -+ self.block_table_np[row_idx].fill(0) - - def add_row(self, block_ids: list[int], row_idx: int) -> None: - self.num_blocks_per_row[row_idx] = 0 -@@ -116,6 +125,10 @@ class MultiGroupBlockTable: - row_idx: int) -> None: - for i, block_table in enumerate(self.block_tables): - block_table.append_row(block_ids[i], row_idx) -+ -+ def reset_row(self, row_idx: int) -> None: -+ for i, block_table in enumerate(self.block_tables): -+ block_table.reset_row(row_idx) - - def add_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None: - for i, block_table in enumerate(self.block_tables): -diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py -index 9471f7512..f9ad56f5c 100644 ---- a/vllm/v1/worker/gpu_model_runner.py -+++ b/vllm/v1/worker/gpu_model_runner.py -@@ -72,6 +72,9 @@ from ..sample.logits_processor import LogitsProcessorManager - from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing, - sanity_check_mm_encoder_outputs, scatter_mm_placeholders) - -+from ucm.integration.vllm.ucm_sparse.state import get_ucm_sparse, has_ucm_sparse -+from ucm.integration.vllm.ucm_sparse.base import UcmSparseMetadata, INVALID_SLOT -+ - if TYPE_CHECKING: - import xgrammar as xgr - import xgrammar.kernels.apply_token_bitmask_inplace_torch_compile as xgr_torch_compile # noqa: E501 -@@ -365,6 +368,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): - """ - # Remove finished requests from the cached states. - for req_id in scheduler_output.finished_req_ids: -+ self.ucm_sparse_request_finished_in_worker(req_id) - self.requests.pop(req_id, None) - self.encoder_cache.pop(req_id, None) - # Remove the finished requests from the persistent batch. -@@ -468,11 +472,13 @@ class GPUModelRunner(LoRAModelRunnerMixin): - # Update the states of the running/resumed requests. 
- is_last_rank = get_pp_group().is_last_rank - req_data = scheduler_output.scheduled_cached_reqs -+ req_sparsed_slots = scheduler_output.req_sparsed_slots - for i, req_id in enumerate(req_data.req_ids): - req_state = self.requests[req_id] - num_computed_tokens = req_data.num_computed_tokens[i] - new_block_ids = req_data.new_block_ids[i] - resumed_from_preemption = req_data.resumed_from_preemption[i] -+ is_sparsed_request = req_sparsed_slots[req_id] != INVALID_SLOT - - # Update the cached states. - if (num_computed_tokens <= req_state.num_computed_tokens): -@@ -518,15 +524,15 @@ class GPUModelRunner(LoRAModelRunnerMixin): - req_state.generator.get_offset()) - - # Update the block IDs. -- if not resumed_from_preemption: -- # Append the new blocks to the existing block IDs. -- for block_ids, new_ids in zip(req_state.block_ids, -- new_block_ids): -- block_ids.extend(new_ids) -- else: -+ if resumed_from_preemption or is_sparsed_request: - # The request is resumed from preemption. - # Replace the existing block IDs with the new ones. - req_state.block_ids = new_block_ids -+ else: -+ # Append the new blocks to the existing block IDs. -+ for block_ids, new_ids in zip(req_state.block_ids, -+ new_block_ids): -+ block_ids.extend(new_ids) - - req_index = self.input_batch.req_id_to_index.get(req_id) - if req_index is None: -@@ -544,6 +550,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): - # Update the persistent batch. - self.input_batch.num_computed_tokens_cpu[req_index] = ( - num_computed_tokens) -+ if is_sparsed_request: -+ self.input_batch.block_table.reset_row(req_index) - self.input_batch.block_table.append_row(new_block_ids, req_index) - - # For the last rank, we don't need to update the token_ids_cpu -@@ -651,6 +659,20 @@ class GPUModelRunner(LoRAModelRunnerMixin): - # Only relevant for models using M-RoPE (e.g, Qwen2-VL) - if self.uses_mrope: - self._calc_mrope_positions(scheduler_output) -+ -+ self.seq_lens_np[:num_reqs] = ( -+ self.input_batch.num_computed_tokens_cpu[:num_reqs] + -+ num_scheduled_tokens) -+ -+ # TODO: improve performance, no `positions_np.copy()` -+ sparsed_positions = positions_np.copy() -+ req_sparsed_slots = scheduler_output.req_sparsed_slots -+ for req_id in self.input_batch.req_id_to_index: -+ is_sparsed_request = req_sparsed_slots[req_id] != INVALID_SLOT -+ req_index = self.input_batch.req_id_to_index[req_id] -+ if is_sparsed_request: -+ sparsed_positions[req_index] -= self.seq_lens_cpu[:num_reqs][req_index] - req_sparsed_slots[req_id] -+ - - # Get token indices. - # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] -@@ -681,11 +703,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): - # block_size. 
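The block-table indexing a few lines below (req_indices * max_num_blocks_per_req + positions // block_size, with positions % block_size as the in-block offset) is easier to follow with concrete numbers. A toy sketch, not taken from vLLM; note that for a sparsed request the patch first shifts the positions down by seq_len minus the sparsed slot count, so the shortened block table is addressed from zero.

import numpy as np

block_size = 128
max_num_blocks_per_req = 4
block_table = np.array([[7, 3, 9, 0]])        # logical block i of request 0 -> physical block id

positions = np.array([0, 130, 300])           # token positions within the request
req_indices = np.zeros_like(positions)        # all three tokens belong to request row 0

block_table_indices = req_indices * max_num_blocks_per_req + positions // block_size
block_numbers = block_table.flatten()[block_table_indices]   # [7, 3, 9]
block_offsets = positions % block_size                       # [0, 2, 44]
slot_mapping = block_numbers * block_size + block_offsets    # physical KV slots
print(slot_mapping)                                          # [896, 386, 1196]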
- block_table_indices = ( - req_indices * block_table.max_num_blocks_per_req + -- positions_np // block_size) -+ sparsed_positions // block_size) - block_table_cpu = block_table.get_cpu_tensor() - block_numbers = block_table_cpu.flatten( - )[block_table_indices].numpy() -- block_offsets = positions_np % block_size -+ block_offsets = sparsed_positions % block_size - np.add( - block_numbers * block_size, - block_offsets, -@@ -695,9 +717,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): - self.query_start_loc_np[0] = 0 - self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens - -- self.seq_lens_np[:num_reqs] = ( -- self.input_batch.num_computed_tokens_cpu[:num_reqs] + -- num_scheduled_tokens) -+ for req_id in self.input_batch.req_id_to_index: -+ req_index = self.input_batch.req_id_to_index[req_id] -+ is_sparsed_request = scheduler_output.req_sparsed_slots[req_id] != INVALID_SLOT -+ if is_sparsed_request: -+ self.seq_lens_np[req_index] = scheduler_output.req_sparsed_slots[req_id] - - # Copy the tensors to the GPU. - self.input_ids[:total_num_scheduled_tokens].copy_( -@@ -709,6 +733,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): - non_blocking=True) - else: - # Common case (1D positions) -+ self.positions_cpu[:total_num_scheduled_tokens] = torch.from_numpy( -+ sparsed_positions[:total_num_scheduled_tokens]) - self.positions[:total_num_scheduled_tokens].copy_( - self.positions_cpu[:total_num_scheduled_tokens], - non_blocking=True) -@@ -1399,6 +1425,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): - skip_cuda_graphs=skip_cuda_graphs, - ): - self.maybe_setup_kv_connector(scheduler_output) -+ self.maybe_execute_ucm_sparse_begin(scheduler_output) - - model_output = self.model( - input_ids=input_ids, -@@ -1408,6 +1435,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): - ) - - finished_dumping = self.maybe_wait_for_kv_save() -+ self.maybe_execute_ucm_sparse_finished() -+ - finished_sending, finished_recving = ( - self.get_finished_kv_transfers(scheduler_output)) - invalid_block_ids = self.get_block_ids_with_load_errors() -@@ -1758,6 +1787,25 @@ class GPUModelRunner(LoRAModelRunnerMixin): - def maybe_wait_for_kv_save() -> Optional[dict[str, list[str]]]: - if has_kv_transfer_group(): - return get_kv_transfer_group().wait_for_save() -+ -+ def maybe_execute_ucm_sparse_begin(self, scheduler_output: "SchedulerOutput"): -+ if not has_ucm_sparse(): -+ return -+ ucm_sparse = get_ucm_sparse() -+ ucm_sparse.build_sparse_meta(scheduler_output, self.requests, self.input_batch) -+ ucm_sparse.execute_begin(scheduler_output) -+ -+ def maybe_execute_ucm_sparse_finished(self): -+ if not has_ucm_sparse(): -+ return -+ ucm_sparse = get_ucm_sparse() -+ ucm_sparse.execute_finished() -+ -+ def ucm_sparse_request_finished_in_worker(self, request_id: str | int): -+ if not has_ucm_sparse(): -+ return -+ ucm_sparse = get_ucm_sparse() -+ ucm_sparse.request_finished_in_worker(request_id) - - @staticmethod - def get_finished_kv_transfers( -diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py -index 1b816b25b..d52a49a2e 100644 ---- a/vllm/v1/worker/gpu_worker.py -+++ b/vllm/v1/worker/gpu_worker.py -@@ -30,6 +30,7 @@ from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput - from vllm.v1.utils import report_usage_stats - from vllm.v1.worker.gpu_model_runner import GPUModelRunner - from vllm.v1.worker.worker_base import WorkerBase -+from ucm.integration.vllm.ucm_sparse.state import ensure_ucm_sparse_initialized - - logger = init_logger(__name__) - -@@ -401,6 +402,7 @@ def 
init_worker_distributed_environment( - parallel_config.pipeline_parallel_size) - - ensure_kv_transfer_initialized(vllm_config) -+ ensure_ucm_sparse_initialized(vllm_config) - - - def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): --- -2.50.1.windows.1 - diff --git a/ucm/integration/vllm/patch/0.9.2/vllm-adapt.patch b/ucm/integration/vllm/patch/0.9.2/vllm-adapt.patch index 2023e5f7..84f5e4c8 100644 --- a/ucm/integration/vllm/patch/0.9.2/vllm-adapt.patch +++ b/ucm/integration/vllm/patch/0.9.2/vllm-adapt.patch @@ -1,121 +1,119 @@ -From 64a94cbdbc38df6f046379c59ac893545ddbd407 Mon Sep 17 00:00:00 2001 -From: flesher0813 <1208954694@qq.com> -Date: Sat, 16 Aug 2025 16:57:04 +0800 -Subject: [PATCH 1/3] [WIP][v1] Support for returning a value when using - wait_for_save +From 9124f6f48b958f2535702d8093495097257a2ccc Mon Sep 17 00:00:00 2001 +From: wenxinwang +Date: Thu, 25 Sep 2025 05:03:42 -0700 +Subject: [PATCH] UCM adaptor -Signed-off-by: flesher0813 <1208954694@qq.com> --- - vllm/v1/core/sched/scheduler.py | 4 +++- - vllm/v1/outputs.py | 1 + - vllm/v1/request.py | 2 +- - vllm/v1/worker/gpu_model_runner.py | 7 ++++--- - 4 files changed, 9 insertions(+), 5 deletions(-) + vllm/attention/layer.py | 45 ++++- + .../kv_transfer/kv_connector/utils.py | 113 ++++++++++++ + .../kv_transfer/kv_connector/v1/base.py | 11 +- + .../v1/shared_storage_connector.py | 7 +- + vllm/v1/core/block_pool.py | 2 +- + vllm/v1/core/kv_cache_manager.py | 11 +- + vllm/v1/core/sched/output.py | 3 + + vllm/v1/core/sched/scheduler.py | 165 +++++++++++++++++- + vllm/v1/core/single_type_kv_cache_manager.py | 3 + + vllm/v1/executor/multiproc_executor.py | 30 +++- + vllm/v1/outputs.py | 5 + + vllm/v1/request.py | 2 +- + vllm/v1/worker/block_table.py | 13 ++ + vllm/v1/worker/gpu_input_batch.py | 9 + + vllm/v1/worker/gpu_model_runner.py | 122 +++++++++++-- + vllm/v1/worker/gpu_worker.py | 25 ++- + 16 files changed, 526 insertions(+), 40 deletions(-) -diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py -index fe552db74..22c0ad8d6 100644 ---- a/vllm/v1/core/sched/scheduler.py -+++ b/vllm/v1/core/sched/scheduler.py -@@ -792,6 +792,8 @@ class Scheduler(SchedulerInterface): - new_token_ids = generated_token_ids - kv_transfer_params = None - -+ if model_runner_output.finished_dumping is not None: -+ request.succeed_dumped_blocks.extend(model_runner_output.finished_dumping.get(req_id, [])) - # Append generated tokens and check for stop. Note that if - # a request is still being prefilled, we expect the model runner - # to return empty token ids for the request. -@@ -842,7 +844,6 @@ class Scheduler(SchedulerInterface): - spec_token_ids[req_index]) - else: - request.spec_token_ids = spec_token_ids[req_index] -- - # Get prompt logprobs for this request. - prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) - if new_token_ids or pooler_output is not None \ -@@ -869,6 +870,7 @@ class Scheduler(SchedulerInterface): - - if not stopped: - new_running.append(request) -+ - self.running = new_running - - # KV Connector: update state for finished KV Transfers. 
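The wait_for_save plumbing in this older revision of the patch (and in its replacement further below) threads the connector's dump results back to the scheduler: maybe_wait_for_kv_save() returns a dict mapping request IDs to block IDs that finished dumping, and update_from_output() appends them to request.succeed_dumped_blocks. A minimal illustration of that contract, with invented data:

# Worker side: value returned by maybe_wait_for_kv_save() / carried in
# ModelRunnerOutput.finished_dumping (dict[str, list[str]]).
finished_dumping = {"req-0": ["block-hash-a", "block-hash-b"]}

# Scheduler side (update_from_output): accumulate per request.
succeed_dumped_blocks: list[str] = []
succeed_dumped_blocks.extend(finished_dumping.get("req-0", []))
assert succeed_dumped_blocks == ["block-hash-a", "block-hash-b"]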
-diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py -index f78623f57..c8388baed 100644 ---- a/vllm/v1/outputs.py -+++ b/vllm/v1/outputs.py -@@ -107,6 +107,7 @@ class ModelRunnerOutput: - # [req_ids] - finished_sending: Optional[set[str]] = None - finished_recving: Optional[set[str]] = None -+ finished_dumping: Optional[dict[str, list[str]]] = None - - # req_id -> num_nans_in_logits - num_nans_in_logits: Optional[dict[str, int]] = None -diff --git a/vllm/v1/request.py b/vllm/v1/request.py -index 9b96f4599..825b77bba 100644 ---- a/vllm/v1/request.py -+++ b/vllm/v1/request.py -@@ -102,7 +102,7 @@ class Request: - # State - # The number of tokens with prefix cache hits. - self.num_cached_tokens = -1 +diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py +index f0ad68b16..89b3da489 100644 +--- a/vllm/attention/layer.py ++++ b/vllm/attention/layer.py +@@ -2,7 +2,6 @@ + # SPDX-FileCopyrightText: Copyright contributors to the vLLM project + """Attention layer.""" + from typing import Any, Dict, List, Optional - -+ self.succeed_dumped_blocks: list[str] = [] - # The number of NaNs in logits. A value greater than 0 - # indicates that the output is corrupted - self.num_nans_in_logits = 0 -diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py -index 5a26e88db..14278bb6a 100644 ---- a/vllm/v1/worker/gpu_model_runner.py -+++ b/vllm/v1/worker/gpu_model_runner.py -@@ -1378,7 +1378,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): - inputs_embeds=inputs_embeds, - ) - -- self.maybe_wait_for_kv_save() -+ finished_dumping = self.maybe_wait_for_kv_save() - finished_sending, finished_recving = ( - self.get_finished_kv_transfers(scheduler_output)) - -@@ -1563,6 +1563,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): - finished_sending=finished_sending, - finished_recving=finished_recving, - num_nans_in_logits=num_nans_in_logits, -+ finished_dumping=finished_dumping - ) - - def propose_draft_token_ids( -@@ -1719,9 +1720,9 @@ class GPUModelRunner(LoRAModelRunnerMixin): - kv_connector.start_load_kv(get_forward_context()) - - @staticmethod -- def maybe_wait_for_kv_save() -> None: -+ def maybe_wait_for_kv_save() -> Optional[dict[str, list[str]]]: - if has_kv_transfer_group(): -- get_kv_transfer_group().wait_for_save() -+ return get_kv_transfer_group().wait_for_save() - - @staticmethod - def get_finished_kv_transfers( --- -2.50.1.windows.1 + import torch + import torch.nn as nn + import torch.nn.functional as F +@@ -22,6 +21,7 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod + from vllm.platforms import _Backend, current_platform + from vllm.utils import direct_register_custom_op + from vllm.v1.attention.backends.utils import validate_kv_sharing_target ++from ucm.integration.vllm.ucm_sparse.state import get_ucm_sparse, has_ucm_sparse -From c00b8ca6f917831ad8f14a5d1449a3fd0a1480f5 Mon Sep 17 00:00:00 2001 -From: flesher0813 <1208954694@qq.com> -Date: Sat, 30 Aug 2025 19:13:35 +0800 -Subject: [PATCH 2/3] [BugFix] adapted workers output for dumped blocks + class Attention(nn.Module): +@@ -409,9 +409,10 @@ def unified_attention( + attn_metadata = attn_metadata[layer_name] + self = forward_context.no_compile_layers[layer_name] + kv_cache = self.kv_cache[forward_context.virtual_engine] ++ maybe_execute_sparse_attention_begin(query, key, value, layer_name, forward_context) + output = self.impl.forward(self, query, key, value, kv_cache, + attn_metadata) +- ++ maybe_execute_sparse_attention_finished(query, key, value, output, layer_name, forward_context) + 
maybe_save_kv_layer_to_connector(layer_name, kv_cache) + return output ---- - .../kv_transfer/kv_connector/utils.py | 109 ++++++++++++++++++ - vllm/v1/executor/multiproc_executor.py | 30 ++++- - vllm/v1/worker/gpu_worker.py | 22 +++- - 3 files changed, 153 insertions(+), 8 deletions(-) +@@ -449,6 +450,7 @@ def unified_attention_with_output( + attn_metadata = attn_metadata[layer_name] + self = forward_context.no_compile_layers[layer_name] + kv_cache = self.kv_cache[forward_context.virtual_engine] ++ maybe_execute_sparse_attention_begin(query, key, value, layer_name, forward_context) + self.impl.forward(self, + query, + key, +@@ -457,7 +459,7 @@ def unified_attention_with_output( + attn_metadata, + output=output, + output_scale=output_scale) +- ++ maybe_execute_sparse_attention_finished(query, key, value, output, layer_name, forward_context) + maybe_save_kv_layer_to_connector(layer_name, kv_cache) + +@@ -479,3 +481,40 @@ direct_register_custom_op( + fake_impl=unified_attention_with_output_fake, + dispatch_key=current_platform.dispatch_key, + ) ++ ++def maybe_execute_sparse_attention_begin( ++ query: torch.Tensor, ++ key: torch.Tensor, ++ value: torch.Tensor, ++ layer_name: str, ++ forward_context: ForwardContext, ++): ++ if not has_ucm_sparse(): ++ return ++ ++ ucm_sparse = get_ucm_sparse() ++ ++ attn_metadata = forward_context.attn_metadata ++ if attn_metadata is None: ++ return ++ ++ ucm_sparse.attention_begin(query, key, value, layer_name, forward_context) ++ ++def maybe_execute_sparse_attention_finished( ++ query: torch.Tensor, ++ key: torch.Tensor, ++ value: torch.Tensor, ++ attn_output: torch.Tensor, ++ layer_name: str, ++ forward_context: ForwardContext, ++): ++ if not has_ucm_sparse(): ++ return ++ ++ ucm_sparse = get_ucm_sparse() ++ ++ attn_metadata = forward_context.attn_metadata ++ if attn_metadata is None: ++ return ++ ++ ucm_sparse.attention_finished(query, key, value, attn_output, layer_name, forward_context) diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py -index 5cbc8ca31..06e71f107 100644 +index 5cbc8ca31..0fee7e74c 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -3,12 +3,18 @@ @@ -128,23 +126,23 @@ index 5cbc8ca31..06e71f107 100644 +from typing import Optional, cast + import torch - + import vllm.envs as envs from vllm import _custom_ops as ops from vllm.config import VllmConfig, get_current_vllm_config from vllm.logger import init_logger +from vllm.v1.outputs import ModelRunnerOutput - + logger = init_logger(__name__) - -@@ -107,3 +113,106 @@ def get_kv_connector_cache_layout(): + +@@ -107,3 +113,110 @@ def get_kv_connector_cache_layout(): "layout to HND for better xfer performance.") return "HND" return "NHD" + + +class KVOutputAggregator: -+ """Utility class to aggregate the output of all workers into a single ++ """Utility class to aggregate the output of all workers into a single + output corresponding to Rank 0 for scheduler.""" + + def __init__(self, world_size: int): @@ -169,7 +167,7 @@ index 5cbc8ca31..06e71f107 100644 + del remaining_count_dict[req_id] + else: + remaining_count_dict[req_id] = new_count -+ ++ + def update_finished_list(req_ids: Optional[dict[str, list[str]]], + remaining_count_dict: dict[str, int], + finished_list: dict[str, list[str]]) -> None: @@ -186,6 +184,7 @@ index 5cbc8ca31..06e71f107 100644 + + finished_sending = set[str]() + finished_recving = set[str]() ++ invalid_block_ids = set[int]() + 
finished_dumping: dict[str, list[str]] = {} + for output in outputs: + update_finished_set(output.finished_sending, @@ -194,6 +193,8 @@ index 5cbc8ca31..06e71f107 100644 + self._recv_remaining_count, finished_recving) + update_finished_list(output.finished_dumping, + self._dump_remaining_count, finished_dumping) ++ if output.invalid_block_ids: ++ invalid_block_ids |= output.invalid_block_ids + + # select output of the worker specified by output_rank + output = outputs[output_rank] @@ -206,6 +207,7 @@ index 5cbc8ca31..06e71f107 100644 + output.finished_sending = finished_sending if finished_sending else None + output.finished_recving = finished_recving if finished_recving else None + output.finished_dumping = finished_dumping if finished_dumping else None ++ output.invalid_block_ids = invalid_block_ids or None + + return output + @@ -244,185 +246,16 @@ index 5cbc8ca31..06e71f107 100644 + output_future.add_done_callback(make_callback(i)) + + return result_future -diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py -index b06b7cc80..22c22a148 100644 ---- a/vllm/v1/executor/multiproc_executor.py -+++ b/vllm/v1/executor/multiproc_executor.py -@@ -26,6 +26,7 @@ from vllm.distributed import (destroy_distributed_environment, - destroy_model_parallel) - from vllm.distributed.device_communicators.shm_broadcast import (Handle, - MessageQueue) -+from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator - from vllm.executor.multiproc_worker_utils import ( - _add_prefix, set_multiprocessing_worker_envs) - from vllm.logger import init_logger -@@ -111,10 +112,14 @@ class MultiprocExecutor(Executor): - if self.max_concurrent_batches > 1: - # Note: must use only 1 IO thread to keep dequeue sequence - # from the response queue -+ # _async_aggregate_workers_output also assumes a single IO thread - self.io_thread_pool = ThreadPoolExecutor( - max_workers=1, thread_name_prefix="mp_exec_io") - - self.output_rank = self._get_output_rank() -+ self.has_connector = self.vllm_config.kv_transfer_config is not None -+ self.kv_output_aggregator = KVOutputAggregator( -+ self.parallel_config.world_size) - - def start_worker_monitor(self): - workers = self.workers -@@ -155,13 +160,30 @@ class MultiprocExecutor(Executor): - self, - scheduler_output, - ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]: -- (output, ) = self.collective_rpc( -+ non_block = self.max_concurrent_batches > 1 -+ -+ if not self.has_connector or self.vllm_config.model_config.use_mla: -+ # get output only from a single worker (output_rank) -+ (output, ) = self.collective_rpc( -+ "execute_model", -+ args=(scheduler_output, ), -+ unique_reply_rank=self.output_rank, -+ non_block=non_block, -+ timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS) -+ return output -+ -+ # get output from all workers -+ outputs = self.collective_rpc( - "execute_model", - args=(scheduler_output, ), -- unique_reply_rank=self.output_rank, -- non_block=self.max_concurrent_batches > 1, -+ non_block=non_block, - timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS) -- return output -+ -+ # aggregate all workers output to a single output -+ if non_block: -+ return self.kv_output_aggregator.async_aggregate( -+ outputs, self.output_rank) -+ return self.kv_output_aggregator.aggregate(outputs, self.output_rank) - - def collective_rpc(self, - method: Union[str, Callable], -diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py -index 9e7e44d06..7117f60b5 100644 ---- a/vllm/v1/worker/gpu_worker.py -+++ 
b/vllm/v1/worker/gpu_worker.py -@@ -1,6 +1,7 @@ - # SPDX-License-Identifier: Apache-2.0 - # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - """A GPU worker class.""" -+import copy - import gc - import os - from typing import TYPE_CHECKING, Optional -@@ -15,7 +16,8 @@ from vllm.device_allocator.cumem import CuMemAllocator - from vllm.distributed import (ensure_model_parallel_initialized, - init_distributed_environment, - set_custom_all_reduce) --from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized -+from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized, -+ has_kv_transfer_group) - from vllm.distributed.parallel_state import get_pp_group, get_tp_group - from vllm.logger import init_logger - from vllm.lora.request import LoRARequest -@@ -24,7 +26,7 @@ from vllm.platforms import current_platform - from vllm.sequence import IntermediateTensors - from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling - from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec --from vllm.v1.outputs import ModelRunnerOutput -+from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput - from vllm.v1.utils import report_usage_stats - from vllm.v1.worker.gpu_model_runner import GPUModelRunner - from vllm.v1.worker.worker_base import WorkerBase -@@ -313,9 +315,21 @@ class Worker(WorkerBase): - assert isinstance(output, IntermediateTensors) - get_pp_group().send_tensor_dict(output.tensors, - all_gather_group=get_tp_group()) -- return None -+ if not has_kv_transfer_group(): -+ return None -+ -+ # In case of PP with kv transfer, we need to pass through the -+ # finished_sending and finished_recving buffers. -+ new_output = EMPTY_MODEL_RUNNER_OUTPUT -+ if output.finished_sending or output.finished_recving or output.finished_dumping: -+ new_output = copy.copy(new_output) -+ new_output.finished_sending = output.finished_sending -+ new_output.finished_recving = output.finished_recving -+ new_output.finished_dumping = output.finished_dumping -+ output = new_output -+ - assert isinstance(output, ModelRunnerOutput) -- return output if self.is_driver_worker else None -+ return output - - def profile(self, is_start: bool = True): - if self.profiler is None: --- -2.50.1.windows.1 - - -From 1c58b7c4d32dc726d87f8e0e34723c331e969b39 Mon Sep 17 00:00:00 2001 -From: flesher0813 <1208954694@qq.com> -Date: Sat, 6 Sep 2025 23:01:17 +0800 -Subject: [PATCH 3/3] [Feat] Adapted from pr 19330 to support recomputing load - failed reqs - -Signed-off-by: flesher0813 <1208954694@qq.com> ---- - .../kv_transfer/kv_connector/utils.py | 4 + - .../kv_transfer/kv_connector/v1/base.py | 9 ++ - .../v1/shared_storage_connector.py | 7 +- - vllm/v1/core/block_pool.py | 2 +- - vllm/v1/core/sched/scheduler.py | 126 ++++++++++++++++++ - vllm/v1/core/single_type_kv_cache_manager.py | 3 + - vllm/v1/outputs.py | 4 + - vllm/v1/worker/gpu_input_batch.py | 9 ++ - vllm/v1/worker/gpu_model_runner.py | 46 ++++++- - vllm/v1/worker/gpu_worker.py | 3 +- - 10 files changed, 203 insertions(+), 10 deletions(-) - -diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py -index 06e71f107..0fee7e74c 100644 ---- a/vllm/distributed/kv_transfer/kv_connector/utils.py -+++ b/vllm/distributed/kv_transfer/kv_connector/utils.py -@@ -158,6 +158,7 @@ class KVOutputAggregator: - - finished_sending = set[str]() - finished_recving = set[str]() -+ invalid_block_ids = set[int]() - finished_dumping: dict[str, list[str]] = {} - for output in 
outputs: - update_finished_set(output.finished_sending, -@@ -166,6 +167,8 @@ class KVOutputAggregator: - self._recv_remaining_count, finished_recving) - update_finished_list(output.finished_dumping, - self._dump_remaining_count, finished_dumping) -+ if output.invalid_block_ids: -+ invalid_block_ids |= output.invalid_block_ids - - # select output of the worker specified by output_rank - output = outputs[output_rank] -@@ -178,6 +181,7 @@ class KVOutputAggregator: - output.finished_sending = finished_sending if finished_sending else None - output.finished_recving = finished_recving if finished_recving else None - output.finished_dumping = finished_dumping if finished_dumping else None -+ output.invalid_block_ids = invalid_block_ids or None - - return output - diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py -index f80b5eba2..2f0b73cb9 100644 +index f80b5eba2..61424b10d 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py -@@ -200,6 +200,15 @@ class KVConnectorBase_V1(ABC): +@@ -200,7 +200,16 @@ class KVConnectorBase_V1(ABC): call to this method (this call or a prior one). """ return None, None -+ +- ++ + def get_block_ids_with_load_errors(self) -> Optional[set[int]]: + """ + Get the set of block IDs that failed to load. @@ -431,9 +264,10 @@ index f80b5eba2..2f0b73cb9 100644 + Returns None if no errors occurred during load. + """ + return None - ++ # ============================== # Scheduler-side methods + # ============================== diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py index 3c574d065..223106def 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py @@ -445,10 +279,10 @@ index 3c574d065..223106def 100644 -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import TYPE_CHECKING - + import safetensors @@ -53,10 +53,7 @@ class ReqMeta: - + @dataclass class SharedStorageConnectorMetadata(KVConnectorMetadata): - requests: list[ReqMeta] @@ -456,7 +290,7 @@ index 3c574d065..223106def 100644 - def __init__(self): - self.requests = [] + requests: list[ReqMeta] = field(default_factory=list) - + def add_request( self, diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py @@ -472,32 +306,168 @@ index d21f94727..1800665c7 100644 return new_full_blocks = blocks[num_cached_blocks:num_full_blocks] assert len(block_hashes) >= num_cached_blocks +diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py +index 6937455e7..c36a25bc5 100644 +--- a/vllm/v1/core/kv_cache_manager.py ++++ b/vllm/v1/core/kv_cache_manager.py +@@ -3,7 +3,7 @@ + + from collections import defaultdict + from dataclasses import dataclass +-from typing import Optional ++from typing import Optional, Union + + from vllm.distributed.kv_events import KVCacheEvent + from vllm.logger import init_logger +@@ -14,6 +14,8 @@ from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock, + from vllm.v1.kv_cache_interface import KVCacheConfig + from vllm.v1.metrics.stats import PrefixCacheStats + from vllm.v1.request import Request, RequestStatus ++from ucm.integration.vllm.ucm_sparse.state import get_ucm_sparse, has_ucm_sparse ++from ucm.integration.vllm.ucm_sparse.base import INVALID_SLOT + + logger = init_logger(__name__) + +@@ -193,6 +195,7 @@ class KVCacheManager: + 
num_draft_tokens: int = 0, + num_lookahead_tokens: int = 0, + delay_cache_blocks: bool = False, ++ num_slots_sparsed: Union[None, int] = None + ) -> Optional[KVCacheBlocks]: + """Add slots for a request with new tokens to append. + +@@ -231,6 +234,12 @@ class KVCacheManager: + """ + if num_new_tokens == 0: + raise ValueError("num_new_tokens must be greater than 0") ++ if num_slots_sparsed != INVALID_SLOT: ++ return get_ucm_sparse().allocate_slots(request, ++ num_slots_sparsed, ++ self.coordinator, ++ self.block_pool, ++ self.kv_cache_config.kv_cache_groups) + + if new_computed_blocks is not None: + new_computed_block_list = new_computed_blocks.blocks +diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py +index d34f39327..a0ab878a5 100644 +--- a/vllm/v1/core/sched/output.py ++++ b/vllm/v1/core/sched/output.py +@@ -155,3 +155,6 @@ class SchedulerOutput: + + # KV Cache Connector metadata. + kv_connector_metadata: Optional[KVConnectorMetadata] = None ++ ++ # modified slots by sparse algorithm ++ req_sparsed_slots: dict[str, int] = None diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py -index 22c0ad8d6..b16e38423 100644 +index fe552db74..cb6f44227 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py -@@ -745,16 +745,28 @@ class Scheduler(SchedulerInterface): +@@ -34,6 +34,8 @@ from vllm.v1.outputs import ModelRunnerOutput + from vllm.v1.request import Request, RequestStatus + from vllm.v1.spec_decode.metrics import SpecDecodingStats + from vllm.v1.structured_output import StructuredOutputManager ++from ucm.integration.vllm.ucm_sparse.state import ensure_ucm_sparse_initialized, get_ucm_sparse, has_ucm_sparse ++from ucm.integration.vllm.ucm_sparse.base import UcmSparseBase, UcmSparseRole, INVALID_SLOT + + logger = init_logger(__name__) + +@@ -79,12 +81,18 @@ class Scheduler(SchedulerInterface): + # will have a corresponding KVConnector with Role=WORKER. + # KV Connector pushes/pull of remote KVs for P/D and offloading. + self.connector = None ++ self.ucm_sparse = None + if self.vllm_config.kv_transfer_config is not None: + assert len(self.kv_cache_config.kv_cache_groups) == 1, ( + "Multiple KV cache groups are not currently supported " + "with KV connectors") + self.connector = KVConnectorFactory.create_connector_v1( + config=self.vllm_config, role=KVConnectorRole.SCHEDULER) ++ # Initialize UCM Sparse if available ++ if "ucm_sparse_config" in vllm_config.kv_transfer_config.kv_connector_extra_config: ++ ensure_ucm_sparse_initialized(vllm_config, role=UcmSparseRole.SCHEDULER) ++ self.ucm_sparse = get_ucm_sparse() ++ logger.info("UCM Sparse initialized successfully: {}".format(self.ucm_sparse)) + + self.kv_event_publisher = EventPublisherFactory.create( + self.kv_events_config, +@@ -201,8 +209,13 @@ class Scheduler(SchedulerInterface): + + # First, schedule the RUNNING requests. 
+ req_index = 0 ++ req_sparsed_slots: dict[str, int] = {} + while req_index < len(self.running) and token_budget > 0: + request = self.running[req_index] ++ num_slots_sparsed = INVALID_SLOT ++ if self.ucm_sparse: ++ num_slots_sparsed = self.ucm_sparse.estimate_num_slots_sparsed(request) ++ req_sparsed_slots.update({request.request_id: num_slots_sparsed}) + + num_new_tokens = (request.num_tokens_with_spec - + request.num_computed_tokens) +@@ -250,7 +263,8 @@ class Scheduler(SchedulerInterface): + request, + num_new_tokens, + num_draft_tokens=num_draft_tokens, +- num_lookahead_tokens=self.num_lookahead_tokens) ++ num_lookahead_tokens=self.num_lookahead_tokens, ++ num_slots_sparsed=num_slots_sparsed) + if new_blocks is None: + # The request cannot be scheduled. + # Preempt the lowest-priority request. +@@ -337,6 +351,10 @@ class Scheduler(SchedulerInterface): + break + + request = self.waiting.peek_request() ++ num_slots_sparsed = INVALID_SLOT ++ if self.ucm_sparse: ++ num_slots_sparsed = self.ucm_sparse.estimate_num_slots_sparsed(request) ++ req_sparsed_slots.update({request.request_id: num_slots_sparsed}) + + # KVTransfer: skip request if still waiting for remote kvs. + if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: +@@ -446,6 +464,7 @@ class Scheduler(SchedulerInterface): + new_computed_blocks, + num_lookahead_tokens=self.num_lookahead_tokens, + delay_cache_blocks=load_kv_async, ++ num_slots_sparsed=num_slots_sparsed + ) + if new_blocks is None: + # The request cannot be scheduled. +@@ -559,6 +578,7 @@ class Scheduler(SchedulerInterface): + scheduled_spec_decode_tokens=scheduled_spec_decode_tokens, + scheduled_encoder_inputs=scheduled_encoder_inputs, + num_common_prefix_blocks=num_common_prefix_blocks, ++ req_sparsed_slots=req_sparsed_slots, + # finished_req_ids is an existing state in the scheduler, + # instead of being newly scheduled in this step. + # It contains the request IDs that are finished in between +@@ -745,23 +765,38 @@ class Scheduler(SchedulerInterface): num_scheduled_tokens = scheduler_output.num_scheduled_tokens pooler_outputs = model_runner_output.pooler_output num_nans_in_logits = model_runner_output.num_nans_in_logits + invalid_block_ids = model_runner_output.invalid_block_ids - + new_running: list[Request] = [] outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list) spec_decoding_stats: Optional[SpecDecodingStats] = None -+ ++ + recovered_req_ids = None + if invalid_block_ids: + # These blocks contain externally computed tokens that failed to + # load. Identify affected requests and adjust their computed token + # count to trigger recomputation of the invalid blocks. + recovered_req_ids = self._handle_invalid_blocks(invalid_block_ids) - + # NOTE(woosuk): As len(self.running) can be up to 1K or more, the below # loop can be a performance bottleneck. We should do our best to avoid # expensive operations inside the loop. for request in self.running: req_id = request.request_id ++ # self.req_meta.stage == SequenceStage.PREFILL and self.req_meta.is_last_chunk ++ ++ + if recovered_req_ids and req_id in recovered_req_ids: + # Skip requests that were recovered from KV load failure + new_running.append(request) @@ -505,7 +475,62 @@ index 22c0ad8d6..b16e38423 100644 num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0) if num_tokens_scheduled == 0: # The request was not scheduled in this step. 
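The estimate_num_slots_sparsed hook called in the scheduling loop above is where the ESA knobs from examples/offline_inference.py (init_window_sz, local_window_sz, min_blocks, sparse_ratio) come into play. The function below is only a hypothetical sketch of how such an estimate could combine those knobs; it is not the shipped ESA implementation, and the INVALID_SLOT value here is a placeholder for the sentinel imported from ucm.integration.vllm.ucm_sparse.base.

INVALID_SLOT = -1  # placeholder; the real sentinel lives in ucm.integration.vllm.ucm_sparse.base

def estimate_num_slots_sparsed(num_blocks: int, block_size: int,
                               init_window_sz: int = 1, local_window_sz: int = 2,
                               min_blocks: int = 4, sparse_ratio: float = 0.3) -> int:
    """Toy estimate: keep the initial and local windows plus a fraction of the middle."""
    if num_blocks <= min_blocks:
        return INVALID_SLOT                      # too short, do not sparsify this step
    middle = num_blocks - init_window_sz - local_window_sz
    kept_blocks = init_window_sz + local_window_sz + max(1, int(middle * sparse_ratio))
    return kept_blocks * block_size              # number of slots the KV cache manager keeps

print(estimate_num_slots_sparsed(num_blocks=100, block_size=128))  # 32 blocks -> 4096 slots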
-@@ -1115,3 +1127,117 @@ class Scheduler(SchedulerInterface): + new_running.append(request) + continue + +- req_index = model_runner_output.req_id_to_index[req_id] ++ req_index = model_runner_output.req_id_to_index[req_id] + generated_token_ids = sampled_token_ids[ + req_index] if sampled_token_ids else [] + +@@ -792,6 +827,12 @@ class Scheduler(SchedulerInterface): + new_token_ids = generated_token_ids + kv_transfer_params = None + ++ if model_runner_output.finished_dumping is not None: ++ request.succeed_dumped_blocks.extend(model_runner_output.finished_dumping.get(req_id, [])) ++ ++ if request.num_output_tokens == 0 and (num_tokens_scheduled + request.num_computed_tokens >= request.num_prompt_tokens): ++ self.connector.connector.commit(request.succeed_dumped_blocks, True) ++ + # Append generated tokens and check for stop. Note that if + # a request is still being prefilled, we expect the model runner + # to return empty token ids for the request. +@@ -842,7 +883,6 @@ class Scheduler(SchedulerInterface): + spec_token_ids[req_index]) + else: + request.spec_token_ids = spec_token_ids[req_index] +- + # Get prompt logprobs for this request. + prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) + if new_token_ids or pooler_output is not None \ +@@ -869,6 +909,7 @@ class Scheduler(SchedulerInterface): + + if not stopped: + new_running.append(request) ++ + self.running = new_running + + # KV Connector: update state for finished KV Transfers. +@@ -927,6 +968,8 @@ class Scheduler(SchedulerInterface): + def add_request(self, request: Request) -> None: + self.waiting.add_request(request) + self.requests[request.request_id] = request ++ if self.ucm_sparse: ++ self.ucm_sparse.request_begin(request.request_id, request.prompt_token_ids) + if self.log_stats: + request.record_event(EngineCoreEventType.QUEUED) + +@@ -976,6 +1019,8 @@ class Scheduler(SchedulerInterface): + + def _free_request(self, request: Request) -> Optional[dict[str, Any]]: + assert request.is_finished() ++ if self.ucm_sparse: ++ self.ucm_sparse.request_finished_in_scheduler(request.request_id) + + delay_free_blocks, kv_xfer_params = self._connector_finished(request) + self.encoder_cache_manager.free(request) +@@ -1113,3 +1158,117 @@ class Scheduler(SchedulerInterface): for req_id in (model_runner_output.finished_sending or ()): logger.debug("Finished sending KV transfer for request %s", req_id) self._free_blocks(self.requests[req_id]) @@ -623,7 +648,6 @@ index 22c0ad8d6..b16e38423 100644 + # Return the IDs of affected running requests to skip in + # update_from_output. 
+ return {r.request_id for r in affected_requests} -\ No newline at end of file diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 5b4718038..d97690ae5 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py @@ -632,33 +656,140 @@ index 5b4718038..d97690ae5 100644 """ num_cached_blocks = self.num_cached_block[request.request_id] num_full_blocks = num_tokens // self.block_size -+ ++ + if num_cached_blocks >= num_full_blocks: + return - + self.block_pool.cache_full_blocks( request=request, +diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py +index b06b7cc80..61cd7110f 100644 +--- a/vllm/v1/executor/multiproc_executor.py ++++ b/vllm/v1/executor/multiproc_executor.py +@@ -26,6 +26,7 @@ from vllm.distributed import (destroy_distributed_environment, + destroy_model_parallel) + from vllm.distributed.device_communicators.shm_broadcast import (Handle, + MessageQueue) ++from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator + from vllm.executor.multiproc_worker_utils import ( + _add_prefix, set_multiprocessing_worker_envs) + from vllm.logger import init_logger +@@ -111,10 +112,14 @@ class MultiprocExecutor(Executor): + if self.max_concurrent_batches > 1: + # Note: must use only 1 IO thread to keep dequeue sequence + # from the response queue ++ # _async_aggregate_workers_output also assumes a single IO thread + self.io_thread_pool = ThreadPoolExecutor( + max_workers=1, thread_name_prefix="mp_exec_io") + + self.output_rank = self._get_output_rank() ++ self.has_connector = self.vllm_config.kv_transfer_config is not None ++ self.kv_output_aggregator = KVOutputAggregator( ++ self.parallel_config.world_size) + + def start_worker_monitor(self): + workers = self.workers +@@ -155,13 +160,30 @@ class MultiprocExecutor(Executor): + self, + scheduler_output, + ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]: +- (output, ) = self.collective_rpc( ++ non_block = self.max_concurrent_batches > 1 ++ ++ if not self.has_connector or self.vllm_config.model_config.use_mla: ++ # get output only from a single worker (output_rank) ++ (output, ) = self.collective_rpc( ++ "execute_model", ++ args=(scheduler_output, ), ++ unique_reply_rank=self.output_rank, ++ non_block=non_block, ++ timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS) ++ return output ++ ++ # get output from all workers ++ outputs = self.collective_rpc( + "execute_model", + args=(scheduler_output, ), +- unique_reply_rank=self.output_rank, +- non_block=self.max_concurrent_batches > 1, ++ non_block=non_block, + timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS) +- return output ++ ++ # aggregate all workers output to a single output ++ if non_block: ++ return self.kv_output_aggregator.async_aggregate( ++ outputs, self.output_rank) ++ return self.kv_output_aggregator.aggregate(outputs, self.output_rank) + + def collective_rpc(self, + method: Union[str, Callable], diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py -index c8388baed..8697150b2 100644 +index f78623f57..8697150b2 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py -@@ -108,6 +108,10 @@ class ModelRunnerOutput: +@@ -107,6 +107,11 @@ class ModelRunnerOutput: + # [req_ids] finished_sending: Optional[set[str]] = None finished_recving: Optional[set[str]] = None - finished_dumping: Optional[dict[str, list[str]]] = None -+ ++ finished_dumping: Optional[dict[str, list[str]]] = None ++ + # IDs of externally computed KV blocks that failed to load. 
+ # Requests referencing these blocks should be rescheduled to recompute them. + invalid_block_ids: Optional[set[int]] = None - + # req_id -> num_nans_in_logits num_nans_in_logits: Optional[dict[str, int]] = None +diff --git a/vllm/v1/request.py b/vllm/v1/request.py +index 9b96f4599..825b77bba 100644 +--- a/vllm/v1/request.py ++++ b/vllm/v1/request.py +@@ -102,7 +102,7 @@ class Request: + # State + # The number of tokens with prefix cache hits. + self.num_cached_tokens = -1 +- ++ self.succeed_dumped_blocks: list[str] = [] + # The number of NaNs in logits. A value greater than 0 + # indicates that the output is corrupted + self.num_nans_in_logits = 0 +diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py +index 8f4e8d64c..733ac1f41 100644 +--- a/vllm/v1/worker/block_table.py ++++ b/vllm/v1/worker/block_table.py +@@ -60,6 +60,15 @@ class BlockTable: + start = self.num_blocks_per_row[row_idx] + self.num_blocks_per_row[row_idx] += num_blocks + self.block_table_np[row_idx, start:start + num_blocks] = block_ids ++ ++ def reset_row( ++ self, ++ row_idx: int, ++ ) -> None: ++ self.num_blocks_per_row[row_idx] = 0 ++ self.block_table[row_idx].fill_(0) ++ self.block_table_cpu[row_idx].fill_(0) ++ self.block_table_np[row_idx].fill(0) + + def add_row(self, block_ids: list[int], row_idx: int) -> None: + self.num_blocks_per_row[row_idx] = 0 +@@ -116,6 +125,10 @@ class MultiGroupBlockTable: + row_idx: int) -> None: + for i, block_table in enumerate(self.block_tables): + block_table.append_row(block_ids[i], row_idx) ++ ++ def reset_row(self, row_idx: int) -> None: ++ for i, block_table in enumerate(self.block_tables): ++ block_table.reset_row(row_idx) + + def add_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None: + for i, block_table in enumerate(self.block_tables): diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 1a79d72be..0e65c98f5 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -46,6 +46,11 @@ class CachedRequestState: - + def __post_init__(self): self.num_prompt_tokens = len(self.prompt_token_ids) + # 'last_generator_offset' and 'last_gelen_last_output_token_ids' are @@ -666,7 +797,7 @@ index 1a79d72be..0e65c98f5 100644 + # invalid (e.g., due to KV load errors). + self.last_generator_offset = 0 if self.generator else None + self.len_last_output_token_ids = len(self.output_token_ids) - + @property def num_tokens(self) -> int: @@ -201,6 +206,7 @@ class InputBatch: @@ -674,7 +805,7 @@ index 1a79d72be..0e65c98f5 100644 # generator should not be included in the dictionary. 
self.generators: dict[int, torch.Generator] = {} + self.generators_last_offset: dict[int, int] = {} - + self.num_logprobs: dict[str, int] = {} # NOTE(rob): num_prompt_logprobs only includes reqs @@ -335,6 +341,9 @@ class InputBatch: @@ -684,16 +815,43 @@ index 1a79d72be..0e65c98f5 100644 + assert (request.last_generator_offset is not None) + self.generators_last_offset[ + req_index] = request.last_generator_offset - + if sampling_params.logprobs is not None: self.num_logprobs[req_id] = sampling_params.logprobs diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py -index 14278bb6a..9471f7512 100644 +index 5a26e88db..2538bf0c2 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py -@@ -475,6 +475,24 @@ class GPUModelRunner(LoRAModelRunnerMixin): +@@ -72,6 +72,9 @@ from ..sample.logits_processor import LogitsProcessorManager + from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing, + sanity_check_mm_encoder_outputs, scatter_mm_placeholders) + ++from ucm.integration.vllm.ucm_sparse.state import get_ucm_sparse, has_ucm_sparse ++from ucm.integration.vllm.ucm_sparse.base import UcmSparseMetadata, INVALID_SLOT ++ + if TYPE_CHECKING: + import xgrammar as xgr + import xgrammar.kernels.apply_token_bitmask_inplace_torch_compile as xgr_torch_compile # noqa: E501 +@@ -365,6 +368,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): + """ + # Remove finished requests from the cached states. + for req_id in scheduler_output.finished_req_ids: ++ self.ucm_sparse_request_finished_in_worker(req_id) + self.requests.pop(req_id, None) + self.encoder_cache.pop(req_id, None) + # Remove the finished requests from the persistent batch. +@@ -468,13 +472,33 @@ class GPUModelRunner(LoRAModelRunnerMixin): + # Update the states of the running/resumed requests. + is_last_rank = get_pp_group().is_last_rank + req_data = scheduler_output.scheduled_cached_reqs ++ req_sparsed_slots = scheduler_output.req_sparsed_slots + for i, req_id in enumerate(req_data.req_ids): + req_state = self.requests[req_id] + num_computed_tokens = req_data.num_computed_tokens[i] + new_block_ids = req_data.new_block_ids[i] resumed_from_preemption = req_data.resumed_from_preemption[i] - ++ is_sparsed_request = req_sparsed_slots[req_id] != INVALID_SLOT + # Update the cached states. + if (num_computed_tokens <= req_state.num_computed_tokens): + # The request was rescheduled after a KV load failure. Clear @@ -712,44 +870,142 @@ index 14278bb6a..9471f7512 100644 + req_index] - len_last_sampled + self.input_batch.num_tokens[req_index] = end_idx + self.input_batch.num_tokens_no_spec[req_index] = end_idx -+ ++ req_state.num_computed_tokens = num_computed_tokens - + if not is_last_rank: -@@ -492,6 +510,12 @@ class GPUModelRunner(LoRAModelRunnerMixin): +@@ -492,17 +516,23 @@ class GPUModelRunner(LoRAModelRunnerMixin): elif num_new_tokens > 0: req_state.output_token_ids.extend( new_token_ids[-num_new_tokens:]) -+ ++ + req_state.len_last_output_token_ids = len( + req_state.output_token_ids) + if req_state.generator: + req_state.last_generator_offset = ( + req_state.generator.get_offset()) - + # Update the block IDs. - if not resumed_from_preemption: -@@ -511,6 +535,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): +- if not resumed_from_preemption: +- # Append the new blocks to the existing block IDs. 
+- for block_ids, new_ids in zip(req_state.block_ids, +- new_block_ids): +- block_ids.extend(new_ids) +- else: ++ if resumed_from_preemption or is_sparsed_request: + # The request is resumed from preemption. + # Replace the existing block IDs with the new ones. + req_state.block_ids = new_block_ids ++ else: ++ # Append the new blocks to the existing block IDs. ++ for block_ids, new_ids in zip(req_state.block_ids, ++ new_block_ids): ++ block_ids.extend(new_ids) + + req_index = self.input_batch.req_id_to_index.get(req_id) + if req_index is None: +@@ -511,10 +541,17 @@ class GPUModelRunner(LoRAModelRunnerMixin): # scheduled in the previous step and needs to be added again. req_ids_to_add.append(req_id) continue -+ ++ + if req_state.generator: + assert (req_state.last_generator_offset is not None) + self.input_batch.generators_last_offset[ + req_index] = req_state.last_generator_offset - + # Update the persistent batch. self.input_batch.num_computed_tokens_cpu[req_index] = ( -@@ -1381,6 +1410,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): - finished_dumping = self.maybe_wait_for_kv_save() + num_computed_tokens) ++ if is_sparsed_request: ++ self.input_batch.block_table.reset_row(req_index) + self.input_batch.block_table.append_row(new_block_ids, req_index) + + # For the last rank, we don't need to update the token_ids_cpu +@@ -622,7 +659,20 @@ class GPUModelRunner(LoRAModelRunnerMixin): + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) + if self.uses_mrope: + self._calc_mrope_positions(scheduler_output) ++ ++ self.seq_lens_np[:num_reqs] = ( ++ self.input_batch.num_computed_tokens_cpu[:num_reqs] + ++ num_scheduled_tokens) + ++ # TODO: improve performance, no `positions_np.copy()` ++ sparsed_positions = positions_np.copy() ++ req_sparsed_slots = scheduler_output.req_sparsed_slots ++ for req_id in self.input_batch.req_id_to_index: ++ is_sparsed_request = req_sparsed_slots[req_id] != INVALID_SLOT ++ req_index = self.input_batch.req_id_to_index[req_id] ++ offset = 0 if req_index == 0 else cu_num_tokens[req_index - 1] # TODO: support MTP ++ if is_sparsed_request: ++ sparsed_positions[offset] = req_sparsed_slots[req_id] - 1 + # Get token indices. + # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] +@@ -652,11 +702,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): + # block_size. + block_table_indices = ( + req_indices * block_table.max_num_blocks_per_req + +- positions_np // block_size) ++ sparsed_positions // block_size) + block_table_cpu = block_table.get_cpu_tensor() + block_numbers = block_table_cpu.flatten( + )[block_table_indices].numpy() +- block_offsets = positions_np % block_size ++ block_offsets = sparsed_positions % block_size + np.add( + block_numbers * block_size, + block_offsets, +@@ -666,9 +716,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): + self.query_start_loc_np[0] = 0 + self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens + +- self.seq_lens_np[:num_reqs] = ( +- self.input_batch.num_computed_tokens_cpu[:num_reqs] + +- num_scheduled_tokens) ++ for req_id in self.input_batch.req_id_to_index: ++ req_index = self.input_batch.req_id_to_index[req_id] ++ is_sparsed_request = scheduler_output.req_sparsed_slots[req_id] != INVALID_SLOT ++ if is_sparsed_request: ++ self.seq_lens_np[req_index] = scheduler_output.req_sparsed_slots[req_id] + + # Copy the tensors to the GPU. 
+ self.input_ids[:total_num_scheduled_tokens].copy_( +@@ -680,6 +732,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): + non_blocking=True) + else: + # Common case (1D positions) ++ self.positions_cpu[:total_num_scheduled_tokens] = torch.from_numpy( ++ positions_np[:total_num_scheduled_tokens]) + self.positions[:total_num_scheduled_tokens].copy_( + self.positions_cpu[:total_num_scheduled_tokens], + non_blocking=True) +@@ -1370,7 +1424,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): + skip_cuda_graphs=skip_cuda_graphs, + ): + self.maybe_setup_kv_connector(scheduler_output) +- ++ self.maybe_execute_ucm_sparse_begin(scheduler_output, attn_metadata) ++ + model_output = self.model( + input_ids=input_ids, + positions=positions, +@@ -1378,9 +1433,12 @@ class GPUModelRunner(LoRAModelRunnerMixin): + inputs_embeds=inputs_embeds, + ) + +- self.maybe_wait_for_kv_save() ++ finished_dumping = self.maybe_wait_for_kv_save() ++ self.maybe_execute_ucm_sparse_finished() ++ finished_sending, finished_recving = ( self.get_finished_kv_transfers(scheduler_output)) + invalid_block_ids = self.get_block_ids_with_load_errors() - + if self.use_aux_hidden_state_outputs: hidden_states, aux_hidden_states = model_output -@@ -1474,7 +1504,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): +@@ -1474,7 +1532,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): # This relies on cuda-specific torch-internal impl details generator = self.input_batch.generators.get(i) if generator is not None: @@ -759,64 +1015,144 @@ index 14278bb6a..9471f7512 100644 # Record the index of the request that should not be sampled, # so that we could clear the sampled tokens before returning. discard_sampled_tokens_req_indices.append(i) -@@ -1563,7 +1594,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): +@@ -1563,6 +1622,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): finished_sending=finished_sending, finished_recving=finished_recving, num_nans_in_logits=num_nans_in_logits, -- finished_dumping=finished_dumping + finished_dumping=finished_dumping, + invalid_block_ids = invalid_block_ids ) - + def propose_draft_token_ids( -@@ -1694,13 +1726,16 @@ class GPUModelRunner(LoRAModelRunnerMixin): +@@ -1693,13 +1754,16 @@ class GPUModelRunner(LoRAModelRunnerMixin): self.maybe_setup_kv_connector(scheduler_output) finished_sending, finished_recving = ( self.get_finished_kv_transfers(scheduler_output)) + invalid_block_ids = self.get_block_ids_with_load_errors() + get_kv_transfer_group().clear_connector_metadata() - + - if not finished_sending and not finished_recving: + if not finished_sending and not finished_recving and not invalid_block_ids: return EMPTY_MODEL_RUNNER_OUTPUT - + output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) output.finished_sending = finished_sending output.finished_recving = finished_recving + output.invalid_block_ids = invalid_block_ids return output - + @staticmethod -@@ -1732,6 +1767,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): - return get_kv_transfer_group().get_finished( +@@ -1719,9 +1783,28 @@ class GPUModelRunner(LoRAModelRunnerMixin): + kv_connector.start_load_kv(get_forward_context()) + + @staticmethod +- def maybe_wait_for_kv_save() -> None: ++ def maybe_wait_for_kv_save() -> Optional[dict[str, list[str]]]: + if has_kv_transfer_group(): +- get_kv_transfer_group().wait_for_save() ++ return get_kv_transfer_group().wait_for_save() ++ ++ def maybe_execute_ucm_sparse_begin(self, scheduler_output: "SchedulerOutput", attn_metadata: CommonAttentionMetadata): ++ if not has_ucm_sparse(): ++ return ++ ucm_sparse = get_ucm_sparse() ++ 
ucm_sparse.build_sparse_meta(scheduler_output, self.requests, self.input_batch, attn_metadata) ++ ucm_sparse.execute_begin(scheduler_output) ++ ++ def maybe_execute_ucm_sparse_finished(self): ++ if not has_ucm_sparse(): ++ return ++ ucm_sparse = get_ucm_sparse() ++ ucm_sparse.execute_finished() ++ ++ def ucm_sparse_request_finished_in_worker(self, request_id: str | int): ++ if not has_ucm_sparse(): ++ return ++ ucm_sparse = get_ucm_sparse() ++ ucm_sparse.request_finished_in_worker(request_id) + + @staticmethod + def get_finished_kv_transfers( +@@ -1732,6 +1815,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): scheduler_output.finished_req_ids) return None, None -+ + + def get_block_ids_with_load_errors(self) -> Optional[set[int]]: + if has_kv_transfer_group(): + return get_kv_transfer_group().get_block_ids_with_load_errors() + return None - ++ def propose_ngram_draft_token_ids( self, + sampled_token_ids: list[list[int]], diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py -index 7117f60b5..1b816b25b 100644 +index 9e7e44d06..d52a49a2e 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py -@@ -321,11 +321,12 @@ class Worker(WorkerBase): - # In case of PP with kv transfer, we need to pass through the - # finished_sending and finished_recving buffers. - new_output = EMPTY_MODEL_RUNNER_OUTPUT -- if output.finished_sending or output.finished_recving or output.finished_dumping: +@@ -1,6 +1,7 @@ + # SPDX-License-Identifier: Apache-2.0 + # SPDX-FileCopyrightText: Copyright contributors to the vLLM project + """A GPU worker class.""" ++import copy + import gc + import os + from typing import TYPE_CHECKING, Optional +@@ -15,7 +16,8 @@ from vllm.device_allocator.cumem import CuMemAllocator + from vllm.distributed import (ensure_model_parallel_initialized, + init_distributed_environment, + set_custom_all_reduce) +-from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized ++from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized, ++ has_kv_transfer_group) + from vllm.distributed.parallel_state import get_pp_group, get_tp_group + from vllm.logger import init_logger + from vllm.lora.request import LoRARequest +@@ -24,10 +26,11 @@ from vllm.platforms import current_platform + from vllm.sequence import IntermediateTensors + from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling + from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec +-from vllm.v1.outputs import ModelRunnerOutput ++from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput + from vllm.v1.utils import report_usage_stats + from vllm.v1.worker.gpu_model_runner import GPUModelRunner + from vllm.v1.worker.worker_base import WorkerBase ++from ucm.integration.vllm.ucm_sparse.state import ensure_ucm_sparse_initialized + + logger = init_logger(__name__) + +@@ -313,9 +316,22 @@ class Worker(WorkerBase): + assert isinstance(output, IntermediateTensors) + get_pp_group().send_tensor_dict(output.tensors, + all_gather_group=get_tp_group()) +- return None ++ if not has_kv_transfer_group(): ++ return None ++ ++ # In case of PP with kv transfer, we need to pass through the ++ # finished_sending and finished_recving buffers. 
++ new_output = EMPTY_MODEL_RUNNER_OUTPUT + if output.finished_sending or output.finished_recving or output.finished_dumping or output.invalid_block_ids: - new_output = copy.copy(new_output) - new_output.finished_sending = output.finished_sending - new_output.finished_recving = output.finished_recving - new_output.finished_dumping = output.finished_dumping ++ new_output = copy.copy(new_output) ++ new_output.finished_sending = output.finished_sending ++ new_output.finished_recving = output.finished_recving ++ new_output.finished_dumping = output.finished_dumping + new_output.invalid_block_ids = output.invalid_block_ids - output = new_output - ++ output = new_output ++ assert isinstance(output, ModelRunnerOutput) --- -2.50.1.windows.1 +- return output if self.is_driver_worker else None ++ return output + + def profile(self, is_start: bool = True): + if self.profiler is None: +@@ -386,6 +402,7 @@ def init_worker_distributed_environment( + parallel_config.pipeline_parallel_size) + + ensure_kv_transfer_initialized(vllm_config) ++ ensure_ucm_sparse_initialized(vllm_config) + + + def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): +-- +2.34.1 diff --git a/ucm/integration/vllm/patch/0.9.2/vllm-ascend-adapt-sparse.patch b/ucm/integration/vllm/patch/0.9.2/vllm-ascend-adapt-sparse.patch deleted file mode 100644 index 4d8c9786..00000000 --- a/ucm/integration/vllm/patch/0.9.2/vllm-ascend-adapt-sparse.patch +++ /dev/null @@ -1,267 +0,0 @@ -From 32e09eba672e6f55b67962bfef63eca9b67fe508 Mon Sep 17 00:00:00 2001 -From: hek14 <1023129548@qq.com> -Date: Mon, 25 Aug 2025 15:48:40 +0800 -Subject: [PATCH] ucm_sparse patch - ---- - vllm_ascend/attention/attention_v1.py | 42 ++++++++++++++++++ - vllm_ascend/worker/model_runner_v1.py | 63 +++++++++++++++++++++++---- - vllm_ascend/worker/worker_v1.py | 2 + - 3 files changed, 98 insertions(+), 9 deletions(-) - -diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py -index 915feb7..0600f35 100644 ---- a/vllm_ascend/attention/attention_v1.py -+++ b/vllm_ascend/attention/attention_v1.py -@@ -36,6 +36,8 @@ from vllm_ascend.ops.attention import vanilla_chunked_prefill - from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p, - nd_to_nz_2d, nd_to_nz_spec) - -+from ucm.integration.vllm.ucm_sparse.state import get_ucm_sparse, has_ucm_sparse -+ - - class AscendAttentionBackend(AttentionBackend): - accept_output_buffer: bool = True -@@ -453,6 +455,8 @@ def unified_ascend_attention_with_output( - attn_metadata = forward_context.attn_metadata - self = forward_context.no_compile_layers[layer_name] - kv_cache = self.kv_cache[forward_context.virtual_engine] -+ -+ maybe_execute_sparse_attention_begin(query, key, value, layer_name, forward_context) - self.impl.forward(self, - query, - key, -@@ -461,6 +465,7 @@ def unified_ascend_attention_with_output( - attn_metadata, - output, - trace_flag=False) -+ maybe_execute_sparse_attention_finished(query, key, value, output, layer_name, forward_context) - maybe_save_kv_layer_to_connector(layer_name, kv_cache) - return - -@@ -492,6 +497,43 @@ def maybe_save_kv_layer_to_connector( - connector.save_kv_layer(layer_name, kv_cache_layer, - attn_metadata) - -+def maybe_execute_sparse_attention_begin( -+ query: torch.Tensor, -+ key: torch.Tensor, -+ value: torch.Tensor, -+ layer_name: str, -+ forward_context: ForwardContext, -+): -+ if not has_ucm_sparse(): -+ return -+ -+ ucm_sparse = get_ucm_sparse() -+ -+ attn_metadata = forward_context.attn_metadata -+ if attn_metadata is None: -+ 
return -+ -+ ucm_sparse.attention_begin(query, key, value, layer_name, forward_context) -+ -+def maybe_execute_sparse_attention_finished( -+ query: torch.Tensor, -+ key: torch.Tensor, -+ value: torch.Tensor, -+ attn_output: torch.Tensor, -+ layer_name: str, -+ forward_context: ForwardContext, -+): -+ if not has_ucm_sparse(): -+ return -+ -+ ucm_sparse = get_ucm_sparse() -+ -+ attn_metadata = forward_context.attn_metadata -+ if attn_metadata is None: -+ return -+ -+ ucm_sparse.attention_finished(query, key, value, attn_output, layer_name, forward_context) -+ - def unified_attention_with_output_fake( - query: torch.Tensor, - key: torch.Tensor, -diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py -index f9cca93..31766b2 100644 ---- a/vllm_ascend/worker/model_runner_v1.py -+++ b/vllm_ascend/worker/model_runner_v1.py -@@ -91,6 +91,9 @@ from vllm_ascend.worker.eagle_proposer_v1 import EagleProposer - from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer - from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch - -+from ucm.integration.vllm.ucm_sparse.state import get_ucm_sparse, has_ucm_sparse -+from ucm.integration.vllm.ucm_sparse.base import UcmSparseMetadata, INVALID_SLOT -+ - if TYPE_CHECKING: - import xgrammar as xgr # type: ignore[import-untyped] - from vllm.v1.core.sched.output import SchedulerOutput -@@ -350,6 +353,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): - """ - # Remove finished requests from the cached states. - for req_id in scheduler_output.finished_req_ids: -+ self.ucm_sparse_request_finished_in_worker(req_id) - self.requests.pop(req_id, None) - self.encoder_cache.pop(req_id, None) - # Remove the finished requests from the persistent batch. -@@ -456,12 +460,14 @@ class NPUModelRunner(LoRAModelRunnerMixin): - - # Update the states of the running/resumed requests. - req_data = scheduler_output.scheduled_cached_reqs -+ req_sparsed_slots = scheduler_output.req_sparsed_slots - is_last_rank = get_pp_group().is_last_rank - for i, req_id in enumerate(req_data.req_ids): - req_state = self.requests[req_id] - num_computed_tokens = req_data.num_computed_tokens[i] - new_block_ids = req_data.new_block_ids[i] - resumed_from_preemption = req_data.resumed_from_preemption[i] -+ is_sparsed_request = req_sparsed_slots[req_id] != INVALID_SLOT - - req_state.num_computed_tokens = num_computed_tokens - if not is_last_rank: -@@ -477,15 +483,15 @@ class NPUModelRunner(LoRAModelRunnerMixin): - req_state.output_token_ids.extend( - new_token_ids[-num_new_tokens:]) - # Update the block IDs. -- if not resumed_from_preemption: -+ if resumed_from_preemption or is_sparsed_request: -+ # The request is resumed from preemption. -+ # Replace the existing block IDs with the new ones. -+ req_state.block_ids = new_block_ids -+ else: - # Append the new blocks to the existing block IDs. - for block_ids, new_ids in zip( # type: ignore[call-overload] - req_state.block_ids, new_block_ids): - block_ids.extend(new_ids) -- else: -- # The request is resumed from preemption. -- # Replace the existing block IDs with the new ones. 
-- req_state.block_ids = new_block_ids - - req_index = self.input_batch.req_id_to_index.get(req_id) - if req_index is None: -@@ -499,6 +505,9 @@ class NPUModelRunner(LoRAModelRunnerMixin): - self.input_batch.num_computed_tokens_cpu[req_index] = ( - num_computed_tokens) - -+ if is_sparsed_request: -+ self.input_batch.block_table.reset_row(req_index) -+ - self.input_batch.block_table.append_row(new_block_ids, req_index) - - if not is_last_rank: -@@ -959,12 +968,21 @@ class NPUModelRunner(LoRAModelRunnerMixin): - num_scheduled_tokens) - seq_lens = self.seq_lens_cpu[:num_reqs] - -+ # TODO: improve performance, no `positions_np.copy()` -+ sparsed_positions = positions_np.copy() -+ req_sparsed_slots = scheduler_output.req_sparsed_slots -+ for req_id in self.input_batch.req_id_to_index: -+ is_sparsed_request = req_sparsed_slots[req_id] != INVALID_SLOT -+ req_index = self.input_batch.req_id_to_index[req_id] -+ if is_sparsed_request: -+ sparsed_positions[req_index] -= seq_lens[req_index] - req_sparsed_slots[req_id] -+ - block_table_indices = (req_indices * self.max_num_blocks_per_req + -- positions_np // self.block_size) -+ sparsed_positions // self.block_size) - - block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor() - block_numbers = block_table_cpu.flatten()[block_table_indices].numpy() -- block_offsets = positions_np % self.block_size -+ block_offsets = sparsed_positions % self.block_size - np.add(block_numbers * self.block_size, - block_offsets, - out=self.slot_mapping_np[:total_num_scheduled_tokens]) -@@ -989,10 +1007,16 @@ class NPUModelRunner(LoRAModelRunnerMixin): - else: - attn_state = AscendAttentionState.PrefillCacheHit - -+ for req_id in self.input_batch.req_id_to_index: -+ is_sparsed_request = req_sparsed_slots[req_id] != INVALID_SLOT -+ req_index = self.input_batch.req_id_to_index[req_id] -+ if is_sparsed_request: -+ seq_lens[req_index] = req_sparsed_slots[req_id] -+ - self.attn_mask = self._make_attention_mask( - seq_lens=seq_lens, - query_lens=num_scheduled_tokens, -- position=positions, -+ position=torch.tensor(sparsed_positions).npu(), - attn_state=attn_state) - self.attn_state = attn_state # type: ignore - -@@ -1131,6 +1155,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): - maybe_converting_weight_acl_format(self.model, - ACL_FORMAT_FRACTAL_ND) - self.maybe_setup_kv_connector(scheduler_output) -+ self.maybe_execute_ucm_sparse_begin(scheduler_output) - - hidden_states = self.model( - input_ids=input_ids, -@@ -1140,6 +1165,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): - **model_kwargs, - ) - finished_dumping = self.maybe_wait_for_kv_save() -+ self.maybe_execute_ucm_sparse_finished() - - use_spec_decode = len( - scheduler_output.scheduled_spec_decode_tokens) > 0 -@@ -2377,7 +2403,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): - if batch_size <= padded_batch_size < selected_batch_size: - selected_batch_size = padded_batch_size - return selected_batch_size -- -+ - @staticmethod - def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"): - # Update KVConnector with the KVConnector metadata forward(). 
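# A small NumPy sketch of the slot-mapping arithmetic that the model-runner hunks
# above rework around `sparsed_positions`: a token's physical cache slot is
# block_table[pos // block_size] * block_size + pos % block_size, so remapping a
# sparsed request's decode position (and overriding its seq_len with the sparsed
# slot count) redirects reads and writes into the compacted block table.
# `toy_slot_mapping` and the example numbers are illustrative assumptions; the
# formula itself is the one used in the patched code.
import numpy as np

def toy_slot_mapping(block_table_row: np.ndarray, positions: np.ndarray, block_size: int) -> np.ndarray:
    block_numbers = block_table_row[positions // block_size]
    block_offsets = positions % block_size
    return block_numbers * block_size + block_offsets

block_size = 4
block_table_row = np.array([7, 3, 9])      # physical block ids assigned to one request
positions = np.array([0, 1, 5, 11])        # logical token positions
print(toy_slot_mapping(block_table_row, positions, block_size))   # -> [28 29 13 39]
# For a sparsed decode step the same formula is evaluated on the adjusted
# positions, so it lands inside the retained (topk/init/local-window) blocks.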
-@@ -2398,3 +2424,22 @@ class NPUModelRunner(LoRAModelRunnerMixin): - def maybe_wait_for_kv_save(): - if has_kv_transfer_group(): - return get_kv_transfer_group().wait_for_save() -+ -+ def maybe_execute_ucm_sparse_begin(self, scheduler_output: "SchedulerOutput"): -+ if not has_ucm_sparse(): -+ return -+ ucm_sparse = get_ucm_sparse() -+ ucm_sparse.build_sparse_meta(scheduler_output, self.requests, self.input_batch) -+ ucm_sparse.execute_begin(scheduler_output) -+ -+ def maybe_execute_ucm_sparse_finished(self): -+ if not has_ucm_sparse(): -+ return -+ ucm_sparse = get_ucm_sparse() -+ ucm_sparse.execute_finished() -+ -+ def ucm_sparse_request_finished_in_worker(self, request_id: str | int): -+ if not has_ucm_sparse(): -+ return -+ ucm_sparse = get_ucm_sparse() -+ ucm_sparse.request_finished_in_worker(request_id) -\ No newline at end of file -diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py -index df03d50..6ea5bf3 100644 ---- a/vllm_ascend/worker/worker_v1.py -+++ b/vllm_ascend/worker/worker_v1.py -@@ -49,6 +49,7 @@ from vllm_ascend.utils import (check_kv_cache_bytes_cache_exist, - read_kv_cache_bytes_from_file, - sleep_mode_enabled, try_register_lib) - from vllm_ascend.worker.model_runner_v1 import NPUModelRunner -+from ucm.integration.vllm.ucm_sparse.state import ensure_ucm_sparse_initialized - - - class NPUWorker(WorkerBase): -@@ -321,6 +322,7 @@ class NPUWorker(WorkerBase): - parallel_config.world_size_across_dp, - ) - ensure_kv_transfer_initialized(self.vllm_config) -+ ensure_ucm_sparse_initialized(self.vllm_config) - - def _init_profiler(self): - # Torch profiler. Enabled and configured through env vars: --- -2.50.1.windows.1 - diff --git a/ucm/integration/vllm/patch/0.9.2/vllm-ascend-adapt.patch b/ucm/integration/vllm/patch/0.9.2/vllm-ascend-adapt.patch index 6c4ca411..91c68dda 100644 --- a/ucm/integration/vllm/patch/0.9.2/vllm-ascend-adapt.patch +++ b/ucm/integration/vllm/patch/0.9.2/vllm-ascend-adapt.patch @@ -1,16 +1,16 @@ -From d5c47a5c2620843cb1af0277ff17768f5e20e057 Mon Sep 17 00:00:00 2001 -From: flesher0813 <1208954694@qq.com> -Date: Mon, 28 Jul 2025 10:58:23 +0800 -Subject: [PATCH 1/2] [Feature]:Add support for the vLLM V1 connector +From 67b10fc431e5aac0155ca5b77cd9a99e35656521 Mon Sep 17 00:00:00 2001 +From: wenxinwang +Date: Thu, 25 Sep 2025 05:31:48 -0700 +Subject: [PATCH] UCM adaptor -Signed-off-by: flesher0813 <1208954694@qq.com> --- - vllm_ascend/attention/attention_v1.py | 33 ++++++++++++++++++++++++ - vllm_ascend/worker/model_runner_v1.py | 37 ++++++++++++++++++++++++--- - 2 files changed, 66 insertions(+), 4 deletions(-) + vllm_ascend/attention/attention_v1.py | 75 ++++++++++++++++++++ + vllm_ascend/worker/model_runner_v1.py | 99 +++++++++++++++++++++++---- + vllm_ascend/worker/worker_v1.py | 25 +++++-- + 3 files changed, 183 insertions(+), 16 deletions(-) diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py -index 7d7f488..915feb7 100644 +index 7d7f488..09c4345 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -24,6 +24,9 @@ import torch_npu @@ -23,7 +23,16 @@ index 7d7f488..915feb7 100644 from vllm.forward_context import ForwardContext, get_forward_context from vllm.utils import direct_register_custom_op from vllm.v1.core.sched.output import SchedulerOutput -@@ -444,6 +447,8 @@ def unified_ascend_attention_with_output( +@@ -33,6 +36,8 @@ from vllm_ascend.ops.attention import vanilla_chunked_prefill + from vllm_ascend.utils import 
(ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p, + nd_to_nz_2d, nd_to_nz_spec) + ++from ucm.integration.vllm.ucm_sparse.state import get_ucm_sparse, has_ucm_sparse ++ + + class AscendAttentionBackend(AttentionBackend): + accept_output_buffer: bool = True +@@ -444,10 +449,14 @@ def unified_ascend_attention_with_output( output: torch.Tensor, layer_name: str, ) -> None: @@ -32,13 +41,20 @@ index 7d7f488..915feb7 100644 forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata self = forward_context.no_compile_layers[layer_name] -@@ -456,8 +461,36 @@ def unified_ascend_attention_with_output( + kv_cache = self.kv_cache[forward_context.virtual_engine] ++ ++ maybe_execute_sparse_attention_begin(query, key, value, layer_name, forward_context) + self.impl.forward(self, + query, + key, +@@ -456,8 +465,74 @@ def unified_ascend_attention_with_output( attn_metadata, output, trace_flag=False) ++ maybe_execute_sparse_attention_finished(query, key, value, output, layer_name, forward_context) + maybe_save_kv_layer_to_connector(layer_name, kv_cache) return - + +def wait_for_kv_layer_from_connector(layer_name: str): + if not has_kv_transfer_group() or not is_v1_kv_transfer_group(): + return @@ -66,11 +82,48 @@ index 7d7f488..915feb7 100644 + return + connector.save_kv_layer(layer_name, kv_cache_layer, + attn_metadata) - ++ ++def maybe_execute_sparse_attention_begin( ++ query: torch.Tensor, ++ key: torch.Tensor, ++ value: torch.Tensor, ++ layer_name: str, ++ forward_context: ForwardContext, ++): ++ if not has_ucm_sparse(): ++ return ++ ++ ucm_sparse = get_ucm_sparse() ++ ++ attn_metadata = forward_context.attn_metadata ++ if attn_metadata is None: ++ return ++ ++ ucm_sparse.attention_begin(query, key, value, layer_name, forward_context) ++ ++def maybe_execute_sparse_attention_finished( ++ query: torch.Tensor, ++ key: torch.Tensor, ++ value: torch.Tensor, ++ attn_output: torch.Tensor, ++ layer_name: str, ++ forward_context: ForwardContext, ++): ++ if not has_ucm_sparse(): ++ return ++ ++ ucm_sparse = get_ucm_sparse() ++ ++ attn_metadata = forward_context.attn_metadata ++ if attn_metadata is None: ++ return ++ ++ ucm_sparse.attention_finished(query, key, value, attn_output, layer_name, forward_context) + def unified_attention_with_output_fake( query: torch.Tensor, diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py -index eabcdbc..f9cca93 100644 +index eabcdbc..e51f46e 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -39,7 +39,10 @@ from vllm.config import CompilationLevel, VllmConfig @@ -85,7 +138,71 @@ index eabcdbc..f9cca93 100644 from vllm.inputs import INPUT_REGISTRY from vllm.logger import logger from vllm.model_executor.layers.fused_moe import FusedMoE -@@ -876,7 +879,8 @@ class NPUModelRunner(LoRAModelRunnerMixin): +@@ -88,6 +91,9 @@ from vllm_ascend.worker.eagle_proposer_v1 import EagleProposer + from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer + from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch + ++from ucm.integration.vllm.ucm_sparse.state import get_ucm_sparse, has_ucm_sparse ++from ucm.integration.vllm.ucm_sparse.base import UcmSparseMetadata, INVALID_SLOT ++ + if TYPE_CHECKING: + import xgrammar as xgr # type: ignore[import-untyped] + from vllm.v1.core.sched.output import SchedulerOutput +@@ -347,6 +353,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): + """ + # Remove finished requests from the cached states. 
+ for req_id in scheduler_output.finished_req_ids: ++ self.ucm_sparse_request_finished_in_worker(req_id) + self.requests.pop(req_id, None) + self.encoder_cache.pop(req_id, None) + # Remove the finished requests from the persistent batch. +@@ -453,12 +460,14 @@ class NPUModelRunner(LoRAModelRunnerMixin): + + # Update the states of the running/resumed requests. + req_data = scheduler_output.scheduled_cached_reqs ++ req_sparsed_slots = scheduler_output.req_sparsed_slots + is_last_rank = get_pp_group().is_last_rank + for i, req_id in enumerate(req_data.req_ids): + req_state = self.requests[req_id] + num_computed_tokens = req_data.num_computed_tokens[i] + new_block_ids = req_data.new_block_ids[i] + resumed_from_preemption = req_data.resumed_from_preemption[i] ++ is_sparsed_request = req_sparsed_slots[req_id] != INVALID_SLOT + + req_state.num_computed_tokens = num_computed_tokens + if not is_last_rank: +@@ -474,15 +483,15 @@ class NPUModelRunner(LoRAModelRunnerMixin): + req_state.output_token_ids.extend( + new_token_ids[-num_new_tokens:]) + # Update the block IDs. +- if not resumed_from_preemption: ++ if resumed_from_preemption or is_sparsed_request: ++ # The request is resumed from preemption. ++ # Replace the existing block IDs with the new ones. ++ req_state.block_ids = new_block_ids ++ else: + # Append the new blocks to the existing block IDs. + for block_ids, new_ids in zip( # type: ignore[call-overload] + req_state.block_ids, new_block_ids): + block_ids.extend(new_ids) +- else: +- # The request is resumed from preemption. +- # Replace the existing block IDs with the new ones. +- req_state.block_ids = new_block_ids + + req_index = self.input_batch.req_id_to_index.get(req_id) + if req_index is None: +@@ -496,6 +505,9 @@ class NPUModelRunner(LoRAModelRunnerMixin): + self.input_batch.num_computed_tokens_cpu[req_index] = ( + num_computed_tokens) + ++ if is_sparsed_request: ++ self.input_batch.block_table.reset_row(req_index) ++ + self.input_batch.block_table.append_row(new_block_ids, req_index) + + if not is_last_rank: +@@ -876,7 +888,8 @@ class NPUModelRunner(LoRAModelRunnerMixin): intermediate_tensors: Optional[IntermediateTensors] = None, ) -> tuple[Union[AscendMetadata, AscendMLAMetadata, AscendTorchairMetadata], torch.Tensor, SpecDecodeMetadata, @@ -95,61 +212,106 @@ index eabcdbc..f9cca93 100644 # Check input valid total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens assert total_num_scheduled_tokens > 0 -@@ -1100,6 +1104,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): +@@ -955,12 +968,22 @@ class NPUModelRunner(LoRAModelRunnerMixin): + num_scheduled_tokens) + seq_lens = self.seq_lens_cpu[:num_reqs] + ++ # TODO: improve performance, no `positions_np.copy()` ++ sparsed_positions = positions_np.copy() ++ req_sparsed_slots = scheduler_output.req_sparsed_slots ++ for req_id in self.input_batch.req_id_to_index: ++ is_sparsed_request = req_sparsed_slots[req_id] != INVALID_SLOT ++ req_index = self.input_batch.req_id_to_index[req_id] ++ offset = 0 if req_index == 0 else cu_num_tokens[req_index - 1] # TODO: support MTP ++ if is_sparsed_request: ++ sparsed_positions[offset] = req_sparsed_slots[req_id] - 1 ++ + block_table_indices = (req_indices * self.max_num_blocks_per_req + +- positions_np // self.block_size) ++ sparsed_positions // self.block_size) + + block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor() + block_numbers = block_table_cpu.flatten()[block_table_indices].numpy() +- block_offsets = positions_np % self.block_size ++ block_offsets = sparsed_positions 
% self.block_size + np.add(block_numbers * self.block_size, + block_offsets, + out=self.slot_mapping_np[:total_num_scheduled_tokens]) +@@ -985,10 +1008,16 @@ class NPUModelRunner(LoRAModelRunnerMixin): + else: + attn_state = AscendAttentionState.PrefillCacheHit + ++ for req_id in self.input_batch.req_id_to_index: ++ is_sparsed_request = req_sparsed_slots[req_id] != INVALID_SLOT ++ req_index = self.input_batch.req_id_to_index[req_id] ++ if is_sparsed_request: ++ seq_lens[req_index] = req_sparsed_slots[req_id] ++ + self.attn_mask = self._make_attention_mask( + seq_lens=seq_lens, + query_lens=num_scheduled_tokens, +- position=positions, ++ position=torch.tensor(sparsed_positions).npu(), + attn_state=attn_state) + self.attn_state = attn_state # type: ignore + +@@ -1100,6 +1129,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): positions = self.positions[:padded_batch_size] - + # Run forward pass + finished_dumping = None with set_forward_context(attn_metadata, self.vllm_config, num_tokens=num_input_tokens): -@@ -1125,6 +1130,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): +@@ -1125,6 +1155,8 @@ class NPUModelRunner(LoRAModelRunnerMixin): assert self.model is not None maybe_converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_ND) + self.maybe_setup_kv_connector(scheduler_output) - ++ self.maybe_execute_ucm_sparse_begin(scheduler_output, attn_metadata) + hidden_states = self.model( input_ids=input_ids, -@@ -1133,6 +1139,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): +@@ -1133,6 +1165,8 @@ class NPUModelRunner(LoRAModelRunnerMixin): inputs_embeds=inputs_embeds, **model_kwargs, ) + finished_dumping = self.maybe_wait_for_kv_save() - ++ self.maybe_execute_ucm_sparse_finished() + use_spec_decode = len( scheduler_output.scheduled_spec_decode_tokens) > 0 -@@ -1163,7 +1170,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): - +@@ -1163,7 +1197,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): + return (attn_metadata, hidden_states, spec_decode_metadata, positions, total_num_scheduled_tokens, logits_indices, aux_hidden_states, - num_scheduled_tokens) + num_scheduled_tokens, finished_dumping) - + def _get_cumsum_and_arange( self, -@@ -1400,7 +1407,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): +@@ -1400,7 +1434,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): return EMPTY_MODEL_RUNNER_OUTPUT (attn_metadata, hidden_states, spec_decode_metadata, positions, num_scheduled_tokens, logits_indices, aux_hidden_states, - num_scheduled_tokens_np) = (self._process_reqs( + num_scheduled_tokens_np, finished_dumping) = (self._process_reqs( scheduler_output, intermediate_tensors)) - + with ProfileExecuteDuration().capture_async("post process"): -@@ -1561,6 +1568,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): +@@ -1561,6 +1595,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): logprobs=logprobs_lists, prompt_logprobs_dict=prompt_logprobs_dict, pooler_output=[], + finished_dumping=finished_dumping ) - + durations = ProfileExecuteDuration().pop_captured_sync() -@@ -2369,3 +2377,24 @@ class NPUModelRunner(LoRAModelRunnerMixin): +@@ -2369,3 +2404,43 @@ class NPUModelRunner(LoRAModelRunnerMixin): if batch_size <= padded_batch_size < selected_batch_size: selected_batch_size = padded_batch_size return selected_batch_size -+ ++ + @staticmethod + def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"): + # Update KVConnector with the KVConnector metadata forward(). 
@@ -170,31 +332,37 @@ index eabcdbc..f9cca93 100644 + def maybe_wait_for_kv_save(): + if has_kv_transfer_group(): + return get_kv_transfer_group().wait_for_save() --- -2.50.1.windows.1 - - -From 0501efb489472b1a08a9447d078f6b9716c8c843 Mon Sep 17 00:00:00 2001 -From: flesher0813 <1208954694@qq.com> -Date: Sat, 30 Aug 2025 19:45:52 +0800 -Subject: [PATCH 2/2] [BugFix] Modify npu worker for aggregating - modelrunner_outputs - ---- - vllm_ascend/worker/worker_v1.py | 23 +++++++++++++++++++---- - 1 file changed, 19 insertions(+), 4 deletions(-) - ++ ++ def maybe_execute_ucm_sparse_begin(self, scheduler_output: "SchedulerOutput", attn_metadata: CommonAttentionMetadata): ++ if not has_ucm_sparse(): ++ return ++ ucm_sparse = get_ucm_sparse() ++ ucm_sparse.build_sparse_meta(scheduler_output, self.requests, self.input_batch, attn_metadata) ++ ucm_sparse.execute_begin(scheduler_output) ++ ++ def maybe_execute_ucm_sparse_finished(self): ++ if not has_ucm_sparse(): ++ return ++ ucm_sparse = get_ucm_sparse() ++ ucm_sparse.execute_finished() ++ ++ def ucm_sparse_request_finished_in_worker(self, request_id: str | int): ++ if not has_ucm_sparse(): ++ return ++ ucm_sparse = get_ucm_sparse() ++ ucm_sparse.request_finished_in_worker(request_id) +\ No newline at end of file diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py -index df03d50..e165506 100644 +index df03d50..a854923 100644 --- a/vllm_ascend/worker/worker_v1.py +++ b/vllm_ascend/worker/worker_v1.py @@ -17,6 +17,7 @@ # Adapted from vllm-project/vllm/vllm/worker/gpu_worker.py # - + +import copy from typing import Optional - + import torch @@ -27,7 +28,8 @@ from vllm import envs from vllm.config import VllmConfig @@ -213,9 +381,17 @@ index df03d50..e165506 100644 -from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput from vllm.v1.worker.worker_base import WorkerBase - + import vllm_ascend.envs as envs_ascend -@@ -222,9 +224,22 @@ class NPUWorker(WorkerBase): +@@ -49,6 +51,7 @@ from vllm_ascend.utils import (check_kv_cache_bytes_cache_exist, + read_kv_cache_bytes_from_file, + sleep_mode_enabled, try_register_lib) + from vllm_ascend.worker.model_runner_v1 import NPUModelRunner ++from ucm.integration.vllm.ucm_sparse.state import ensure_ucm_sparse_initialized + + + class NPUWorker(WorkerBase): +@@ -222,9 +225,22 @@ class NPUWorker(WorkerBase): assert isinstance(output, IntermediateTensors) get_pp_group().send_tensor_dict(output.tensors, all_gather_group=get_tp_group()) @@ -237,9 +413,17 @@ index df03d50..e165506 100644 assert isinstance(output, ModelRunnerOutput) - return output if self.is_driver_worker else None + return output - + def load_model(self) -> None: if self.vllm_config.model_config.enable_sleep_mode: --- -2.50.1.windows.1 +@@ -321,6 +337,7 @@ class NPUWorker(WorkerBase): + parallel_config.world_size_across_dp, + ) + ensure_kv_transfer_initialized(self.vllm_config) ++ ensure_ucm_sparse_initialized(self.vllm_config) + + def _init_profiler(self): + # Torch profiler. 
Enabled and configured through env vars: +-- +2.34.1 diff --git a/ucm/integration/vllm/uc_connector.py b/ucm/integration/vllm/uc_connector.py index 33fe55ed..b6657e90 100644 --- a/ucm/integration/vllm/uc_connector.py +++ b/ucm/integration/vllm/uc_connector.py @@ -734,7 +734,7 @@ def request_finished( block_info = self.request_block_infos.pop(request.request_id, None) if hasattr(request, "succeed_dumped_blocks") and request.succeed_dumped_blocks: logger.debug(f"commit {request.succeed_dumped_blocks} to True.") - self.connector.commit(request.succeed_dumped_blocks, True) + # self.connector.commit(request.succeed_dumped_blocks, True) if block_info is not None: cancel_blocks = [ block_info.block_hashes[i] diff --git a/ucm/integration/vllm/ucm_sparse/base.py b/ucm/integration/vllm/ucm_sparse/base.py index c74f56ae..918c7e71 100644 --- a/ucm/integration/vllm/ucm_sparse/base.py +++ b/ucm/integration/vllm/ucm_sparse/base.py @@ -189,12 +189,14 @@ def update_state_after_alloc(self, request: Request, num_blocks: int): pass def build_sparse_meta( - self, - scheduler_output, - requests, - input_batch, + self, scheduler_output, requests, input_batch, attn_metadata ) -> UcmSparseMetadata: """ Build the sparse metadata for this step. """ pass + + def allocate_slots( + self, request, num_slots_sparsed, coordinator, block_pool, kv_cache_groups + ): + pass diff --git a/ucm/integration/vllm/ucm_sparse/factory.py b/ucm/integration/vllm/ucm_sparse/factory.py index 220a5c72..79692035 100644 --- a/ucm/integration/vllm/ucm_sparse/factory.py +++ b/ucm/integration/vllm/ucm_sparse/factory.py @@ -30,9 +30,10 @@ def loader() -> type[UcmSparseBase]: def create_sparse_method( cls, config: "VllmConfig", role: UcmSparseRole ) -> UcmSparseBase: - sparse_method_name = config.kv_transfer_config.kv_connector_extra_config[ - "ucm_sparse_method" - ] + ucm_cfg = config.kv_transfer_config.kv_connector_extra_config.get( + "ucm_sparse_config" + ) + sparse_method_name, _ = next(iter(ucm_cfg.items())) if sparse_method_name in cls._registry: sparse_method_cls = cls._registry[sparse_method_name]() else: diff --git a/ucm/integration/vllm/ucm_sparse/state.py b/ucm/integration/vllm/ucm_sparse/state.py index fc987926..4b6dacc2 100644 --- a/ucm/integration/vllm/ucm_sparse/state.py +++ b/ucm/integration/vllm/ucm_sparse/state.py @@ -38,13 +38,13 @@ def ensure_ucm_sparse_initialized( # Check if UCM sparse is enabled if ( - "ucm_sparse_method" + "ucm_sparse_config" not in vllm_config.kv_transfer_config.kv_connector_extra_config ): return sparse_method_name = vllm_config.kv_transfer_config.kv_connector_extra_config[ - "ucm_sparse_method" + "ucm_sparse_config" ] if _UCM_SPARSE_AGENT is None: diff --git a/ucm/sandbox/sparse/kvcomp/README.md b/ucm/sandbox/sparse/kvcomp/README.md index 64c81db0..b010e7d9 100644 --- a/ucm/sandbox/sparse/kvcomp/README.md +++ b/ucm/sandbox/sparse/kvcomp/README.md @@ -101,7 +101,7 @@ python ucm/sandbox/sparse/kvcomp/offline_inference_kvcomp.py ``` ### Basic Usage -Similr to UCM's `offline_inference_esa.py` examples. We only need to specify `ucm_sparse_method` to be `KVComp` and specify a KVComp config file in `kvcomp_config_path`, as shown below. +Similr to UCM's `offline_inference_esa.py` examples. We only need to specify `ucm_sparse_config` to be `KVComp` and specify a KVComp config file in `kvcomp_config_path`, as shown below. ```python ... 
@@ -115,7 +115,7 @@ ktc = KVTransferConfig( "max_cache_size": 5368709120, "kv_block_size": 262144, }, - "ucm_sparse_method": "KvComp", + "ucm_sparse_config": "KvComp", "kvcomp_config_path": "configs/kvcomp_qwen3_4B_config.json", }, ) diff --git a/ucm/ucm_sparse/esa.py b/ucm/ucm_sparse/esa.py index 6a9c6a5c..2261032d 100644 --- a/ucm/ucm_sparse/esa.py +++ b/ucm/ucm_sparse/esa.py @@ -1,13 +1,19 @@ +import hashlib import math -import time +import pickle from dataclasses import dataclass -from functools import wraps -from typing import Dict, List, Union +from functools import cache +from typing import Dict, List, Optional, Union +import numpy as np import torch +from numpy.typing import NDArray from vllm.config import VllmConfig +from vllm.distributed.kv_transfer import get_kv_transfer_group from vllm.forward_context import ForwardContext from vllm.sequence import SequenceStage +from vllm.v1.core.kv_cache_manager import KVCacheBlocks +from vllm.v1.core.kv_cache_utils import NONE_HASH from vllm.v1.request import Request from ucm.integration.vllm.ucm_sparse.base import ( @@ -16,51 +22,56 @@ UcmSparseMetadata, UcmSparseRole, ) -from ucm.store.connector.factory import UcmConnectorFactory -from ucm.store.connector.ucmstore import Task, UcmKVStoreBase +from ucm.store.base import Task, UcmKVStoreBase +from ucm.ucm_sparse.retrieval import retrieval_backend +from ucm.ucm_sparse.retrieval.retrieval_worker import RetrievalWorker +ReqType = Union[str, int] +HashType = Union[str, int] -def stat(func): - @wraps(func) - def wrapper(*args, **kwargs): - wrapper.call_count += 1 - start = time.perf_counter_ns() - result = func(*args, **kwargs) - end = time.perf_counter_ns() - cost = end - start - wrapper.time_costs.append(cost) - return result +data = None - wrapper.call_count = 0 - wrapper.time_costs = [] - return wrapper +class ReprePool: + def __init__(self, num_slots): + self.free_slots = set(range(num_slots)) + self.allocated = set() -ReqType = Union[str, int] -HashType = Union[str, int] + def allocate(self, num_new_slots): + assert len(self.free_slots) >= num_new_slots, "Not enough free slots" + allocated = list(self.free_slots)[:num_new_slots] + self.free_slots.difference_update(allocated) + self.allocated.update(allocated) + return allocated -# TODO: add ESA specific config in kv_transfer_config -> extra_config -INIT_WINDOW_SZ = 1 -LOCAL_WINDOW_SZ = 2 -SPARSE_RATIO = 0.3 -RETRIEVAL_STRIDE = 4 + def free(self, slots): + self.free_slots.update(slots) + self.allocated.difference_update(slots) @dataclass class ReqMeta: request_id: ReqType index_in_batch: int - num_prompt_tokens: int - num_output_tokens: int num_scheduled_tokens: int num_computed_tokens: int - num_sparsed_tokens: int vllm_block_ids: list[int] + query_start_loc: int + prompt_token_ids: list[int] + output_token_ids: list[int] @property def step(self) -> int: return self.num_output_tokens + @property + def num_prompt_tokens(self) -> int: + return len(self.prompt_token_ids) + + @property + def num_output_tokens(self) -> int: + return len(self.output_token_ids) + @property def stage(self) -> SequenceStage: return ( @@ -90,26 +101,28 @@ def add_request( self, request_id: ReqType, index_in_batch: int, - num_prompt_tokens: int, - num_output_tokens: int, num_scheduled_tokens: int, num_computed_tokens: int, - num_sparsed_tokens: int, vllm_block_ids: list[int], + query_start_loc: int, + prompt_token_ids: list[int], + output_token_ids: list[int], ) -> None: + meta = ReqMeta( request_id=request_id, index_in_batch=index_in_batch, - 
num_prompt_tokens=num_prompt_tokens, - num_output_tokens=num_output_tokens, num_scheduled_tokens=num_scheduled_tokens, num_computed_tokens=num_computed_tokens, - num_sparsed_tokens=num_sparsed_tokens, vllm_block_ids=vllm_block_ids, + query_start_loc=query_start_loc, + prompt_token_ids=prompt_token_ids, + output_token_ids=output_token_ids, ) self.requests.append(meta) +@cache def get_offset(block_shape, rank, tp_size, precision, layer_id, is_v, is_mla) -> int: block_size, num_key_heads_per_tp, head_size = block_shape k_min_data_block_size = block_size * num_key_heads_per_tp * head_size * precision @@ -123,6 +136,25 @@ def get_offset(block_shape, rank, tp_size, precision, layer_id, is_v, is_mla) -> return v_offset if is_v else k_offset +@cache +def md5(input) -> int: + input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL) + md5_bytes = hashlib.md5(input_bytes).digest() + return int.from_bytes(md5_bytes, byteorder="big") + + +@cache +def block_hash_func(parent_block_hash, curr_block_token_ids, extra_keys): + if not parent_block_hash: + parent_block_hash = NONE_HASH + curr_block_token_ids_tuple = tuple(curr_block_token_ids) + return md5((parent_block_hash, curr_block_token_ids_tuple, extra_keys)) + + +def task_hash_func(block_ids, store_type, tensor_type): + return hash((tuple(block_ids), store_type, tensor_type)) + + class ReqStatePerLayer: # handle single request per layer @@ -133,98 +165,71 @@ def __init__( rank: int, tp_size: int, store_instance: UcmKVStoreBase, + vllm_config: VllmConfig, + retrieval_worker: Optional[RetrievalWorker] = None, + repre_pool: Optional[ReprePool] = None, ): self.layer_name = layer_name self.layer_id = int(layer_name.split(".")[2]) - self.block_repre: torch.Tensor = ( - None ## shape: blks, num_key_heads_per_tp, head_size - ) - self.init_window: tuple[torch.Tensor, torch.Tensor] = None - self.local_window: tuple[torch.Tensor, torch.Tensor] = None + self.slots = [] + self.slots_to_relative_indexes = {} + self.repre_pool: ReprePool | None = repre_pool self.store_instance = store_instance + self.retrieval_worker: Optional[RetrievalWorker] = retrieval_worker + self.retrieval_task = None self.req_meta = req_meta - self.block_size = None + self.vllm_config = vllm_config + self.block_size = vllm_config.cache_config.block_size self.k_cache = None self.v_cache = None self.rank = rank self.tp_size = tp_size self.tasks: Dict[str, Task] = {} - self.init_window_sz = INIT_WINDOW_SZ - self.local_window_sz = LOCAL_WINDOW_SZ - - @classmethod - def req_state_hash(cls, req_id, layer_name): - return hash((req_id, layer_name)) - - @classmethod - def block_hash(cls, request_id, block_id): - return f"req_{request_id}_blk_{block_id}" - - @classmethod - def task_hash(cls, block_ids, store_type, tensor_type): - return hash((tuple(block_ids), store_type, tensor_type)) - - def update_meta(self, req_meta: ReqMeta, forward_context: ForwardContext): - self.req_meta = req_meta + self.esa_cfg = vllm_config.kv_transfer_config.kv_connector_extra_config[ + "ucm_sparse_config" + ]["ESA"] + self.indexes: Optional[NDArray[np.int64]] = None + self.block_hashes = None + self.pre_topk_block_hashes: Dict[int, str] = {} + self.sparse_range: int = 0 + self.init_static_flag = False + + self.num_key_heads = vllm_config.model_config.get_num_kv_heads( + vllm_config.parallel_config + ) + self.head_size = vllm_config.model_config.get_head_size() + self.sparse_range = self.get_sparse_prefill_range() - def retrieval(self, query: torch.Tensor, top_k: int): - if top_k >= self.block_repre.shape[0]: - 
n_blocks = self.block_repre.shape[0] - block_ids = list( - range(self.init_window_sz, n_blocks - self.local_window_sz + 1) + def set_block_hashes(self, token_ids): + if self.block_hashes is not None: + return + self.block_hashes = [] + parent_block_hash_value = None + req_extra_keys = None + for start in range(0, len(token_ids), self.block_size): + end = start + self.block_size + block_token_ids = token_ids[start:end] + if len(block_token_ids) < self.block_size: + break + curr_block_token_ids_tuple = tuple(block_token_ids) + block_hash = block_hash_func( + parent_block_hash_value, curr_block_token_ids_tuple, req_extra_keys ) - block_hashes = [ - f"{self.block_hash(self.req_meta.request_id, id_)}" for id_ in block_ids - ] - return block_hashes - ntokens, num_q_heads, _ = query.shape - if num_q_heads > self.num_key_heads: - query = query.view(ntokens, self.num_key_heads, -1, self.head_size) - query = query.mean(2) - elif num_q_heads < self.num_key_heads: - query = torch.repeat_interleave(query, self.num_key_heads // num_q_heads, 1) + self.block_hashes.append(str(block_hash)) + parent_block_hash_value = block_hash - retrieval_start = self.init_window_sz - retrieval_end = self.block_repre.shape[0] - self.local_window_sz + 1 - block_repre_ = self.block_repre[retrieval_start:retrieval_end] - - scores = torch.einsum("qnd,knd->nqk", query, block_repre_) - scores = scores.softmax(-1) - scores = scores.sum(0).sum(0) - topk_ret = torch.topk(scores, top_k) - topk_index = topk_ret.indices - topk_index = ( - topk_index.sort().values - ) # TODO: remove this, don't need to sort in decode - block_ids = [id.item() + self.init_window_sz for id in topk_index] - block_hashes = [ - f"{self.block_hash(self.req_meta.request_id, id_)}" for id_ in block_ids - ] - return block_hashes - - def construct_init_and_local_window(self): - vllm_block_ids = self.req_meta.vllm_block_ids - # TODO: make sure we don't need to clone() - self.init_window = ( - self.k_cache[vllm_block_ids[: self.init_window_sz]], - self.v_cache[vllm_block_ids[: self.init_window_sz]], - ) - local_window_sz = min( - self.local_window_sz, len(vllm_block_ids[self.init_window_sz :]) - ) - if local_window_sz > 0: - self.local_window = ( - self.k_cache[vllm_block_ids[-local_window_sz:]], - self.v_cache[vllm_block_ids[-local_window_sz:]], - ) + def update_meta(self, req_meta: ReqMeta): + self.req_meta = req_meta def launch_transfer_task(self, transfer_type, block_hashes, vllm_block_ids): fn = getattr(self.store_instance, transfer_type) length = len(block_hashes) block_shape = (self.block_size, self.num_key_heads, self.head_size) - precision = self.k_cache.untyped_storage().element_size() + precision = self.k_cache.storage().element_size() # TODO: consider is_mla here is_mla = False + + block_shape = tuple(block_shape) offsets_k = [ get_offset( block_shape, @@ -247,55 +252,118 @@ def launch_transfer_task(self, transfer_type, block_hashes, vllm_block_ids): is_mla=is_mla, ) ] * length + key_src_tensors = [self.k_cache[id_] for id_ in vllm_block_ids] value_src_tensors = [self.v_cache[id_] for id_ in vllm_block_ids] + task_k = fn(block_hashes, offsets_k, key_src_tensors) task_v = fn(block_hashes, offsets_v, value_src_tensors) - task_k_hash = self.task_hash(block_hashes, transfer_type, "key") + + task_k_hash = task_hash_func(block_hashes, transfer_type, "key") self.tasks[task_k_hash] = task_k - task_v_hash = self.task_hash(block_hashes, transfer_type, "value") + task_v_hash = task_hash_func(block_hashes, transfer_type, "value") self.tasks[task_v_hash] = task_v 
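Note: the prefix-chained hashing used by set_block_hashes above (an md5 over (parent_hash, block_token_ids, extra_keys), one hash per full block, trailing partial block skipped) can be illustrated with the minimal, self-contained sketch below. It is not part of the patch: the real helpers substitute vLLM's NONE_HASH for the first block's missing parent and memoize md5/block_hash_func with functools.cache, which this toy version omits.

import hashlib
import pickle


def md5_int(obj) -> int:
    # Stable integer digest of any picklable object (mirrors md5() in this patch).
    digest = hashlib.md5(pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)).digest()
    return int.from_bytes(digest, byteorder="big")


def chain_block_hashes(token_ids, block_size=4, extra_keys=None):
    # Each full block folds the parent hash into its own hash, so two requests
    # sharing a token prefix produce identical leading hashes; the trailing
    # partial block (if any) is ignored, as in set_block_hashes.
    hashes, parent = [], None
    for start in range(0, len(token_ids), block_size):
        block = tuple(token_ids[start : start + block_size])
        if len(block) < block_size:
            break
        parent = md5_int((parent, block, extra_keys))
        hashes.append(str(parent))
    return hashes


print(chain_block_hashes(list(range(10))))  # 10 tokens, block_size 4 -> 2 hashes
print(chain_block_hashes(list(range(8))))   # shared prefix -> the same 2 hashes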
def extract_block_repre(self, vllm_block_ids): return self.k_cache[vllm_block_ids].mean(1) - def save_blocks(self, num_blocks_need_dump): - if num_blocks_need_dump <= 0: - return - vllm_block_ids = self.req_meta.vllm_block_ids - num_blocks_dumped = 0 if self.block_repre is None else self.block_repre.shape[0] - block_ids = list( - range(num_blocks_dumped, num_blocks_dumped + num_blocks_need_dump) - ) - block_hashes = [ - f"{self.block_hash(self.req_meta.request_id, id_)}" for id_ in block_ids - ] - if self.req_meta.stage == SequenceStage.PREFILL: - vllm_block_ids_dump = vllm_block_ids[ - num_blocks_dumped : num_blocks_dumped + num_blocks_need_dump - ] - else: - # TODO: handle spec_decode here - vllm_block_ids_dump = vllm_block_ids[-1:] - self.launch_transfer_task("dump", block_hashes, vllm_block_ids_dump) - repre = self.extract_block_repre(vllm_block_ids_dump) - # TODO: pre-allocate can speed up here - if self.block_repre is None: - self.block_repre = repre - else: - self.block_repre = torch.cat([self.block_repre, repre]) - - def maybe_register_kv_cache(self, forward_context: ForwardContext): - if self.block_size: + def maybe_register_static_data(self, forward_context: ForwardContext): + if self.init_static_flag: return attn = forward_context.no_compile_layers[self.layer_name] kv_cache = attn.kv_cache[forward_context.virtual_engine] - # TODO: consider is_mla here + # TODO not mla self.k_cache = kv_cache[0] self.v_cache = kv_cache[1] - self.block_size = self.k_cache.shape[1] - self.num_key_heads = self.k_cache.shape[2] - self.head_size = self.k_cache.shape[3] + self.set_block_hashes(self.req_meta.prompt_token_ids) + self.init_static_flag = True + + def wait_transfer_task_done(self): + assert len(self.tasks) > 0 + for task_hash, task in self.tasks.items(): + # TODO: handle exceptions + ret = self.store_instance.wait(task) + self.tasks.clear() # reset + + def start_retrieval(self, batch_query, forward_context): + query_start_loc = self.req_meta.query_start_loc + query_len = self.req_meta.num_scheduled_tokens + query = batch_query[query_start_loc : query_start_loc + query_len] + ntokens, num_q_heads, _ = query.shape + if num_q_heads > self.num_key_heads: + query = query.view(ntokens, self.num_key_heads, -1, self.head_size) + query = query.mean(2) + elif num_q_heads < self.num_key_heads: + query = torch.repeat_interleave(query, self.num_key_heads // num_q_heads, 1) + query_flat = query.reshape(query.shape[0], -1) + top_k = int(self.sparse_range * self.esa_cfg["sparse_ratio"]) + indexes = [self.slots] + self.retrieval_task = self.retrieval_worker.submit( + query_flat, topk=top_k, indexes=indexes + ) + + def wait_retrieval_and_start_load(self): + self.retrieval_worker.wait(self.retrieval_task) + result = self.retrieval_worker.get_result(self.retrieval_task) + choosed_slots = result["indices"][0] + rel_block_ids = [self.slots_to_relative_indexes[int(e)] for e in choosed_slots] + block_hashes = [self.block_hashes[id_] for id_ in rel_block_ids] + top_k = int(self.sparse_range * self.esa_cfg["sparse_ratio"]) + sparse_vllm_block_ids = self.req_meta.vllm_block_ids[:top_k] + + # load delta + diff_vllm_block_ids = set(sparse_vllm_block_ids) + diff_block_hashes = set(block_hashes) + if len(self.pre_topk_block_hashes) == 0: + self.pre_topk_block_hashes = { + blk_id: blk_hash + for (blk_id, blk_hash) in zip(sparse_vllm_block_ids, block_hashes) + } + else: + matched = {} + for k in sparse_vllm_block_ids: + if ( + k in self.pre_topk_block_hashes + and self.pre_topk_block_hashes[k] in diff_block_hashes + ): + 
matched[k] = self.pre_topk_block_hashes[k] + diff_vllm_block_ids.remove(k) + diff_block_hashes.remove(matched[k]) + self.pre_topk_block_hashes = matched + for diff_blk_id, diff_blk_hash in zip( + diff_vllm_block_ids, diff_block_hashes + ): + self.pre_topk_block_hashes[diff_blk_id] = diff_blk_hash + + self.launch_transfer_task( + "load", list(diff_block_hashes), list(diff_vllm_block_ids) + ) + self.retrieval_task = None + + def get_sparse_prefill_range(self): + if (self.req_meta.num_prompt_tokens % self.block_size) == 0: + sparse_range = ( + self.req_meta.num_prompt_tokens // self.block_size + - self.esa_cfg["local_window_sz"] + ) + else: + sparse_range = math.floor( + self.req_meta.num_prompt_tokens / self.block_size + ) - (self.esa_cfg["local_window_sz"] - 1) + return sparse_range + + def block_repre_data(self): + vllm_block_ids = self.req_meta.vllm_block_ids + vllm_block_ids_dump = vllm_block_ids[: self.sparse_range] + repre = self.extract_block_repre(vllm_block_ids_dump) + repre_flat = repre.reshape(repre.shape[0], -1) + new_slots = self.repre_pool.allocate(self.sparse_range) + og_len = len(self.slots) + for i, slot in enumerate(new_slots): + self.slots_to_relative_indexes[slot] = og_len + i + self.slots.extend(new_slots) + vals = repre_flat.to("cpu", non_blocking=True, dtype=torch.float32) + data[self.layer_id][new_slots] = vals def attention_begin( self, @@ -304,49 +372,12 @@ def attention_begin( value: torch.Tensor, forward_context: ForwardContext, ) -> None: - if self.req_meta.step % RETRIEVAL_STRIDE != 1: - return - index_in_batch = self.req_meta.index_in_batch - if isinstance(forward_context.attn_metadata, dict): - attn_md = forward_context.attn_metadata[self.layer_name] - else: - attn_md = forward_context.attn_metadata - query_start_loc = attn_md.query_start_loc[index_in_batch] - query_len = self.req_meta.num_scheduled_tokens - current_query = query[query_start_loc : query_start_loc + query_len] - - vllm_block_ids = self.req_meta.vllm_block_ids[ - self.init_window_sz : -self.local_window_sz - ] - self.wait_for_task_done() - self.prepare_init_and_local_window() # last dump task(possible) - # NOTE: sync style - topk_block_hashes = self.retrieval(current_query, len(vllm_block_ids)) - self.launch_transfer_task("load", topk_block_hashes, vllm_block_ids) - - self.wait_for_task_done() - - # NOTE: Some sparse attention algorithms need to modify attn_metadata here - - def prepare_init_and_local_window(self): - if self.req_meta.step != 1: - return - - vllm_block_ids = self.req_meta.vllm_block_ids - self.k_cache[vllm_block_ids[: self.init_window_sz]] = self.init_window[0] - self.v_cache[vllm_block_ids[: self.init_window_sz]] = self.init_window[1] - - if self.local_window is None: - return - - self.k_cache[vllm_block_ids[-self.local_window_sz :]] = self.local_window[0] - self.v_cache[vllm_block_ids[-self.local_window_sz :]] = self.local_window[1] - - def wait_for_task_done(self): - for task_hash, task in self.tasks.items(): - # TODO: handle exceptions here, refer to UcmKVConnector - ret = self.store_instance.wait(task) - self.tasks.clear() + self.maybe_register_static_data(forward_context) + if self.req_meta.step % self.esa_cfg["retrieval_stride"] == 1: + if self.req_meta.step == 1: + self.start_retrieval(query, forward_context) + self.wait_retrieval_and_start_load() + self.wait_transfer_task_done() def attention_finished( self, @@ -356,30 +387,61 @@ def attention_finished( attn_output: torch.Tensor, forward_context: ForwardContext, ) -> None: - self.maybe_register_kv_cache(forward_context) 
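Note: to make the decode-time pipelining easier to follow, here is a rough, hypothetical trace helper (not part of the patch) that paraphrases the step % retrieval_stride branches in attention_begin above and in the new attention_finished body just below, assuming retrieval_stride = 5. The top-k retrieval is submitted after the attention of a step with step % 5 == 2, its result is collected and the delta KV load launched after a step with step % 5 == 0, and the load is waited on before the attention of the next step (step % 5 == 1); step 1 runs the whole cycle synchronously.

RETRIEVAL_STRIDE = 5  # assumed stride, for illustration only


def esa_decode_actions(step: int) -> list[str]:
    # Lists the ESA actions fired around decode step `step`
    # (step = number of generated tokens so far).
    actions = []
    if step % RETRIEVAL_STRIDE == 1:                   # before attention
        if step == 1:
            actions.append("submit retrieval + load top-k blocks (synchronous)")
        actions.append("wait for pending KV-load transfer")
    if step % RETRIEVAL_STRIDE == 2:                   # after attention
        actions.append("submit async top-k retrieval")
    if step != 0 and step % RETRIEVAL_STRIDE == 0:     # after attention
        actions.append("collect retrieval result, launch delta KV load")
    return actions


for s in range(1, 12):
    print(f"step {s:2d}: {esa_decode_actions(s)}")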
- num_tokens_updated = ( - self.req_meta.num_computed_tokens + self.req_meta.num_scheduled_tokens + should_save = ( + self.req_meta.stage == SequenceStage.PREFILL and self.req_meta.is_last_chunk ) - num_blocks_dumped = 0 if self.block_repre is None else self.block_repre.shape[0] - num_full_blocks = num_tokens_updated // self.block_size - num_blocks_need_dump = num_full_blocks - num_blocks_dumped - self.save_blocks(num_blocks_need_dump) - if self.req_meta.stage == SequenceStage.PREFILL and self.req_meta.is_last_chunk: - self.construct_init_and_local_window() - self.wait_for_task_done() + if should_save: + self.block_repre_data() + else: + if self.req_meta.step == 0: + return + if self.req_meta.step % self.esa_cfg["retrieval_stride"] == 2: + self.start_retrieval(query, forward_context) + if self.req_meta.step % self.esa_cfg["retrieval_stride"] == 0: + self.wait_retrieval_and_start_load() class ESA(UcmSparseBase): # handle batch def __init__(self, vllm_config: VllmConfig, role: UcmSparseRole): super().__init__(vllm_config, role) - self.req_states: dict[str, ReqStatePerLayer] = {} + self.req_states: dict[str, List[ReqStatePerLayer]] = {} self.rank = vllm_config.parallel_config.rank self.tp_size = vllm_config.parallel_config.tensor_parallel_size - self.block_size = vllm_config.cache_config.block_size - config = {"max_cache_size": 5368709120, "device": self.rank, "role": "worker"} - self.connector = UcmConnectorFactory.create_connector("UcmDramStore", config) - # TODO: consider init self.is_mla here + self.connector = get_kv_transfer_group().connector + self.esa_cfg = vllm_config.kv_transfer_config.kv_connector_extra_config[ + "ucm_sparse_config" + ]["ESA"] + self.total_num_hidden_layers = ( + vllm_config.model_config.hf_config.num_hidden_layers + ) + + global data + + if data is None: + parallel_config = vllm_config.parallel_config + num_slots = ( + vllm_config.model_config.max_model_len + * vllm_config.scheduler_config.max_num_seqs + // vllm_config.cache_config.block_size + ) + dim = ( + vllm_config.model_config.get_num_kv_heads(parallel_config) + * vllm_config.model_config.get_head_size() + ) + data = [ + torch.empty((num_slots, dim), dtype=torch.float32) + for _ in range(self.total_num_hidden_layers) + ] + self.layer_pools: list[ReprePool] = [ + ReprePool(num_slots) for _ in range(self.total_num_hidden_layers) + ] + + self.retrieval_workers: List[RetrievalWorker] = [] + for i in range(self.total_num_hidden_layers): + backend_src = data[i] + backend = retrieval_backend.RetrievalWorkerBackend(backend_src) + self.retrieval_workers.append(RetrievalWorker(backend)) def attention_begin( self, @@ -390,15 +452,25 @@ def attention_begin( forward_context: ForwardContext, ) -> None: for req_meta in self._sparse_metadata.requests: - req_state_hash = ReqStatePerLayer.req_state_hash( - req_meta.request_id, layer_name - ) - if req_state_hash not in self.req_states: - self.req_states[req_state_hash] = ReqStatePerLayer( - req_meta, layer_name, self.rank, self.tp_size, self.connector + layer_id = int(layer_name.split(".")[2]) + if req_meta.request_id not in self.req_states: + if self.req_states.get(req_meta.request_id) is None: + self.req_states[req_meta.request_id] = [ + None + ] * self.total_num_hidden_layers + if self.req_states[req_meta.request_id][layer_id] is None: + self.req_states[req_meta.request_id][layer_id] = ReqStatePerLayer( + req_meta, + layer_name, + self.rank, + self.tp_size, + self.connector, + self._vllm_config, + self.retrieval_workers[layer_id], + self.layer_pools[layer_id], ) - req_state 
= self.req_states[req_state_hash] - req_state.update_meta(req_meta, forward_context) + req_state = self.req_states[req_meta.request_id][layer_id] + req_state.update_meta(req_meta) req_state.attention_begin(query, key, value, forward_context) def attention_finished( @@ -411,33 +483,31 @@ def attention_finished( forward_context: ForwardContext, ) -> None: for req_meta in self._sparse_metadata.requests: - req_state_hash = ReqStatePerLayer.req_state_hash( - req_meta.request_id, layer_name - ) - if req_state_hash not in self.req_states: - self.req_states[req_state_hash] = ReqStatePerLayer( - req_meta, layer_name, self.rank, self.tp_size, self.connector + layer_id = int(layer_name.split(".")[2]) + if req_meta.request_id not in self.req_states: + if self.req_states.get(req_meta.request_id) is None: + self.req_states[req_meta.request_id] = [ + None + ] * self.total_num_hidden_layers + if self.req_states[req_meta.request_id][layer_id] is None: + self.req_states[req_meta.request_id][layer_id] = ReqStatePerLayer( + req_meta, + layer_name, + self.rank, + self.tp_size, + self.connector, + self._vllm_config, + self.retrieval_workers[layer_id], + self.layer_pools[layer_id], ) - req_state = self.req_states[req_state_hash] - req_state.update_meta(req_meta, forward_context) + req_state = self.req_states[req_meta.request_id][layer_id] + req_state.update_meta(req_meta) req_state.attention_finished( query, key, value, attn_output, forward_context ) - def wait_all_task_done(self): - pass - - def execute_finished(self): - pass - - def execute_finished(self): - pass - def build_sparse_meta( - self, - scheduler_output, - requests, - input_batch, + self, scheduler_output, requests, input_batch, attn_metadata ) -> UcmSparseMetadata: sparse_meta = ESASparseMetaData() for ( @@ -445,45 +515,86 @@ def build_sparse_meta( num_scheduled_tokens, ) in scheduler_output.num_scheduled_tokens.items(): req_state = requests[req_id] - if len(req_state.prompt_token_ids) > self.block_size: - sparse_meta.add_request( - req_id, - input_batch.req_id_to_index[req_id], - len(req_state.prompt_token_ids), - len(req_state.output_token_ids), - num_scheduled_tokens, - req_state.num_computed_tokens, - scheduler_output.req_sparsed_slots[req_id], - req_state.block_ids[0], - ) + if ( + len(req_state.prompt_token_ids) + <= self._vllm_config.cache_config.block_size + ): + return + + if isinstance(attn_metadata, dict): + attn_metadata = next(iter(attn_metadata.values())) + sparse_meta.add_request( + req_id, + input_batch.req_id_to_index[req_id], + num_scheduled_tokens, + req_state.num_computed_tokens, + req_state.block_ids[0], + attn_metadata.query_start_loc[input_batch.req_id_to_index[req_id]], + req_state.prompt_token_ids, + req_state.output_token_ids, + ) self._sparse_metadata = sparse_meta def request_begin(self, request_id: ReqType, prompt_token_ids: List[int]): pass - def request_finished_in_scheduler(self, request_id: ReqType): - pass - def request_finished_in_worker(self, request_id: ReqType): - pass - - def update_state_after_alloc(self, request: Request, num_blocks: int): - pass + for layer_state in self.req_states[request_id]: + layer_state.repre_pool.free(layer_state.slots) + del self.req_states[request_id] def estimate_num_slots_sparsed(self, request: Request) -> int: if ( request.num_output_tokens == 0 - or request.num_prompt_tokens < self.block_size + or request.num_prompt_tokens + < self._vllm_config.cache_config.block_size * self.esa_cfg["min_blocks"] ): return INVALID_SLOT - num_blocks = math.ceil(request.num_tokens / 
self.block_size) - mid_window_sz = int( - (num_blocks - INIT_WINDOW_SZ - LOCAL_WINDOW_SZ) * SPARSE_RATIO + prompt_len = request.num_prompt_tokens + output_len = request.num_output_tokens + block_size = self._vllm_config.cache_config.block_size + if (flaw := prompt_len % block_size) == 0: + sparse_range = prompt_len // block_size - self.esa_cfg["local_window_sz"] + local_window = block_size * self.esa_cfg["local_window_sz"] + output_len + else: + sparse_range = math.floor(prompt_len / block_size) - ( + self.esa_cfg["local_window_sz"] - 1 + ) + local_window = ( + flaw + block_size * (self.esa_cfg["local_window_sz"] - 1) + output_len + ) + return ( + int(sparse_range * self.esa_cfg["sparse_ratio"]) * block_size + local_window + ) + + def allocate_slots( + self, request, num_slots_sparsed, coordinator, block_pool, kv_cache_groups + ): + block_size = self._vllm_config.cache_config.block_size + num_blocks_need = math.ceil(num_slots_sparsed / block_size) + allocated_blocks = coordinator.get_blocks(request.request_id)[0] + returned_blocks = [] + kept_blocks = [] + num_blocks_original = len(allocated_blocks) + for i, block in enumerate(allocated_blocks): + if i >= num_blocks_original - num_blocks_need: + kept_blocks.append(block) + else: + returned_blocks.append(block) + block_pool._maybe_evict_cached_block(block) + block_pool.free_blocks(returned_blocks) + + coordinator.single_type_managers[0].req_to_blocks[ + request.request_id + ] = kept_blocks + + new_computed_block_list = tuple([] for _ in range(len(kv_cache_groups))) + num_blocks_to_allocate = coordinator.get_num_blocks_to_allocate( + request_id=request.request_id, + num_tokens=num_slots_sparsed, + new_computed_blocks=new_computed_block_list, ) - flaw = request.num_tokens % self.block_size - if flaw: - flaw = self.block_size - flaw - num_tokens_sparsed = ( - INIT_WINDOW_SZ + mid_window_sz + LOCAL_WINDOW_SZ - ) * self.block_size - flaw - return num_tokens_sparsed + if num_blocks_to_allocate > block_pool.get_num_free_blocks(): + return None + coordinator.allocate_new_blocks(request.request_id, num_slots_sparsed) + return KVCacheBlocks(tuple([kept_blocks])) diff --git a/ucm/ucm_sparse/retrieval/__init__.py b/ucm/ucm_sparse/retrieval/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ucm/ucm_sparse/retrieval/retrieval_worker.py b/ucm/ucm_sparse/retrieval/retrieval_worker.py new file mode 100644 index 00000000..db8a398a --- /dev/null +++ b/ucm/ucm_sparse/retrieval/retrieval_worker.py @@ -0,0 +1,93 @@ +import time + +import numpy as np +import torch + +# import retrieval_backend +from ucm.ucm_sparse.retrieval import retrieval_backend + + +class RetrievalWorker: + # handle torch -> numpy && float16/bfloat16 -> float32. 
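Note: as a back-of-the-envelope check on estimate_num_slots_sparsed above, with an assumed block_size of 128, local_window_sz = 2 and sparse_ratio = 0.3, a 10 000-token prompt that has produced 40 output tokens keeps int(77 * 0.3) = 23 sparse blocks plus a 184-token local window, i.e. 3128 of 10 040 slots (about 31%). A standalone re-derivation (hypothetical helper, mirroring the branch on prompt_len % block_size):

import math


def estimated_sparsed_slots(prompt_len: int, output_len: int,
                            block_size: int = 128,
                            local_window_sz: int = 2,
                            sparse_ratio: float = 0.3) -> int:
    # Keep sparse_ratio of the "retrievable" prompt blocks plus an exact local
    # window (partial tail block + trailing full blocks + generated tokens).
    flaw = prompt_len % block_size
    if flaw == 0:
        sparse_range = prompt_len // block_size - local_window_sz
        local_window = block_size * local_window_sz + output_len
    else:
        sparse_range = math.floor(prompt_len / block_size) - (local_window_sz - 1)
        local_window = flaw + block_size * (local_window_sz - 1) + output_len
    return int(sparse_range * sparse_ratio) * block_size + local_window


print(estimated_sparsed_slots(10_000, 40))  # 3128 of 10040 slots (~31 %)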
+ def __init__(self, cpp_worker): + self.cpp_worker = cpp_worker + + @classmethod + def handle_input(cls, input): + if input.dtype != torch.float32: + input = input.to(torch.float32) + input = input.to("cpu", non_blocking=True) + return input + + def submit(self, query, topk, indexes): + q = self.handle_input(query) + req_id = self.cpp_worker.submit(q, topk, indexes) + return req_id + + def poll(self, req_id): + return self.cpp_worker.poll(req_id) # Returns True if ready + + def get_result(self, req_id): + return self.cpp_worker.get_result(req_id) + + def wait(self, req_id): + return self.cpp_worker.wait(req_id) + + +if __name__ == "__main__": + ################# data + batch_size = 2 + dim = 1024 + kv_cache_blocks = 25600 + data = torch.rand(kv_cache_blocks, dim).to(torch.float32) + print("data created", data.shape) + + backend = retrieval_backend.RetrievalWorkerBackend(data) + worker = RetrievalWorker(backend) + topk = 3000 + search_blocks_range = 8000 + tpot = 30 / 1000 + + indexes = np.arange(batch_size * search_blocks_range).reshape( + batch_size, search_blocks_range + ) + + query = torch.rand(batch_size, dim).to(torch.float32) + + #################### cpp async version + req_id = worker.submit(query, topk=topk, indexes=indexes) + + #################### LLM decode begin + time.sleep(tpot * 3) + #################### LLM decode done + + # Poll and get result (in a real program, you'd likely use asyncio or threading) + begin = time.time() + worker.wait(req_id) + result = worker.get_result(req_id) + print("cpp spent:", time.time() - begin) + + ################### numpy version + begin = time.time() + data_indexed = ( + data[indexes.flatten()].reshape(indexes.shape[0], indexes.shape[1], dim).numpy() + ) + query = RetrievalWorker.handle_input(query) + scores = np.matmul(query[:, None, :], data_indexed.transpose((0, 2, 1))) + scores = scores[:, 0, :] + topk_elements = np.partition(scores, -topk, -1)[:, -topk:] + topk_indices = np.argpartition(scores, -topk, -1)[:, -topk:] + topk_indices = indexes[np.arange(indexes.shape[0])[:, None], topk_indices] + print("numpy spent: ", time.time() - begin) + + ## compare + cpp_elements = np.sort(result["scores"], 1) + cpp_indices = np.sort(result["indices"], 1) + + np_elements = np.sort(topk_elements, 1) + np_indices = np.sort(topk_indices, 1) + + diff_elements = np.abs(np_elements - cpp_elements) + diff_indices = np.abs(np_indices - cpp_indices) + + print(f"diff topk: {diff_indices.max()}")
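Note: a condensed sketch of the per-layer slot bookkeeping this patch introduces. ESA.__init__ builds one ReprePool per layer over the shared float32 representation buffer, block_repre_data allocates one slot per block in the sparse prefill range and records slot -> relative-block-index, and request_finished_in_worker returns the slots when the request completes. The ReprePool below is an abbreviated copy of the class added in this patch; the pool size and block counts are made up for illustration.

class ReprePool:
    # Abbreviated copy of the pool added in this patch: a free-list of row
    # indices into the per-layer representation tensor data[layer_id].
    def __init__(self, num_slots: int):
        self.free_slots = set(range(num_slots))
        self.allocated = set()

    def allocate(self, num_new_slots: int) -> list[int]:
        assert len(self.free_slots) >= num_new_slots, "Not enough free slots"
        slots = list(self.free_slots)[:num_new_slots]
        self.free_slots.difference_update(slots)
        self.allocated.update(slots)
        return slots

    def free(self, slots) -> None:
        self.free_slots.update(slots)
        self.allocated.difference_update(slots)


# Illustrative lifecycle for one request on one layer:
pool = ReprePool(num_slots=16)
slots = pool.allocate(5)                    # block_repre_data: 5 sparse prefill blocks
slots_to_relative_indexes = {s: i for i, s in enumerate(slots)}
assert len(pool.free_slots) == 11
# ... decode: retrieval is restricted to `slots`; chosen slots are mapped back
# through slots_to_relative_indexes to recover the corresponding block hashes ...
pool.free(slots)                            # request_finished_in_worker
assert len(pool.free_slots) == 16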