diff --git a/examples/offline_inference_kvcomp.py b/examples/offline_inference_kvcomp.py new file mode 100644 index 00000000..595850be --- /dev/null +++ b/examples/offline_inference_kvcomp.py @@ -0,0 +1,167 @@ +import contextlib +import json +import os +import sys +import time +from dataclasses import asdict + +from transformers import AutoTokenizer + +# Third Party +from vllm import LLM, SamplingParams +from vllm.config import KVTransferConfig +from vllm.engine.arg_utils import EngineArgs + +from ucm.logger import init_logger + +logger = init_logger(__name__) +model = "" +path_to_dataset = "" +data_dir = "" +tokenizer = None + + +def setup_environment_variables(): + os.environ["VLLM_USE_V1"] = "1" + os.environ["PYTHONHASHSEED"] = "123456" + + global model, path_to_dataset, data_dir, tokenizer + model = os.getenv("MODEL_PATH", "/home/models/Qwen2.5-14B-Instruct") + if not os.path.isdir(model): + model = input("Enter path to model, e.g. /home/models/Qwen2.5-14B-Instruct: ") + if not os.path.isdir(model): + print("Exiting. Incorrect model_path") + sys.exit(1) + + path_to_dataset = os.getenv( + "DATASET_PATH", "/home/data/Longbench/data/multifieldqa_zh.jsonl" + ) + if not os.path.isfile(path_to_dataset): + path_to_dataset = input( + "Enter path to one of the longbench dataset, e.g. /home/data/Longbench/data/multifieldqa_zh.jsonl: " + ) + if not os.path.isfile(path_to_dataset): + print("Exiting. Incorrect dataset path") + sys.exit(1) + + data_dir = os.getenv("DATA_DIR", "/home/data/kv_cache") + data_dir = input( + "Enter the directory for UCMStore to save kv cache, e.g. /home/data/kv_cache: " + ) + if not os.path.isdir(data_dir): + create = input(f"Directory {data_dir} dose not exist. Create it? (Y/n): ") + if create.lower() == "y": + os.makedirs(data_dir, exist_ok=True) + else: + print("Exiting. 
Directory not created.") + sys.exit(1) + + tokenizer = AutoTokenizer.from_pretrained(model, use_chat_template=True) + + +@contextlib.contextmanager +def build_llm_with_uc(module_path: str, name: str, model: str): + ktc = KVTransferConfig( + kv_connector=name, + kv_connector_module_path=module_path, + kv_role="kv_both", + kv_connector_extra_config={ + "ucm_connector_name": "UcmNfsStore", + "ucm_connector_config": { + "storage_backends": data_dir, + "kv_block_size": 33554432, + }, + "ucm_sparse_config": { + "KvComp": { + "init_window_sz": 1, + "local_window_sz": 2, + "min_blocks": 4, + "sparse_ratio": 0.3, + "retrieval_stride": 5, + } + }, + # "kvcomp_config_path": "unified-cache-management/ucm/sparse/kvcomp/configs/kvcomp_deepseek_v2_lite_config.json", + "kvcomp_config_path": "unified-cache-management/ucm/sparse/kvcomp/configs/kvcomp_qwq_32B_config.json", + }, + ) + + llm_args = EngineArgs( + model=model, + kv_transfer_config=ktc, + max_model_len=32768, + gpu_memory_utilization=0.8, + max_num_batched_tokens=30000, + block_size=128, + enforce_eager=True, + distributed_executor_backend="mp", + tensor_parallel_size=2, + trust_remote_code=True, + ) + + llm = LLM(**asdict(llm_args)) + try: + yield llm + finally: + logger.info("LLM engine is exiting.") + + +def print_output( + llm: LLM, + prompt: list[str], + sampling_params: SamplingParams, + req_str: str, +): + start = time.time() + outputs = llm.generate(prompt, sampling_params) + print("-" * 50) + for output in outputs: + generated_text = output.outputs[0].text + print(f"Generated text: {generated_text!r}") + print(f"Generation took {time.time() - start:.2f} seconds, {req_str} request done.") + print("-" * 50) + + +def main(): + module_path = "ucm.integration.vllm.uc_connector" + name = "UnifiedCacheConnectorV1" + setup_environment_variables() + + def get_prompt(prompt): + messages = [ + { + "role": "system", + "content": "先读问题,再根据下面的文章内容回答问题,不要进行分析,不要重复问题,用简短的语句给出答案。\n\n例如:“全国美国文学研究会的第十八届年会在哪所大学举办的?”\n回答应该为:“xx大学”。\n\n", + }, + {"role": "user", "content": prompt}, + ] + return tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + add_special_tokens=True, + ) + + with build_llm_with_uc(module_path, name, model) as llm: + prompts = [] + batch_size = 10 + assert os.path.isfile( + path_to_dataset + ), f"Incorrect dataset path. 
Please specify the dataset path by `export DATASET_PATH=/path/to/longbench/multifieldqa_zh.jsonl`" + with open(path_to_dataset, "r") as f: + lines = f.readlines() + for i in range(batch_size): + line = lines[i] + data = json.loads(line) + prompt = f"""阅读以下文字并用中文简短回答:\n\n{data["context"]}\n\n现在请基于上面的文章回答下面的问题,只告诉我答案,不要输出任何其他字词。\n\n问题:{data["input"]}\n回答:""" + prompts.append(get_prompt(prompt)) + + sampling_params = SamplingParams( + temperature=0, top_p=0.95, max_tokens=256, ignore_eos=False + ) + + print_output(llm, prompts, sampling_params, "first") + print_output(llm, prompts, sampling_params, "second") + + +if __name__ == "__main__": + main() diff --git a/ucm/sparse/esa/CMakeLists.txt b/ucm/sparse/esa/CMakeLists.txt index a7bdc945..f6192b35 100644 --- a/ucm/sparse/esa/CMakeLists.txt +++ b/ucm/sparse/esa/CMakeLists.txt @@ -1 +1,42 @@ +set(NUMA_INSTALL_DIR ${CMAKE_CURRENT_BINARY_DIR}/numa_install) +FetchContent_Declare( + numactl + URL https://github.com/numactl/numactl/releases/download/v2.0.16/numactl-2.0.16.tar.gz + TLS_VERIFY OFF +) +FetchContent_MakeAvailable(numactl) +if(NOT EXISTS "${NUMA_INSTALL_DIR}/lib/libnuma.so") + message(STATUS "Configuring numactl...") + execute_process( + COMMAND ./configure --prefix=${NUMA_INSTALL_DIR} + WORKING_DIRECTORY ${numactl_SOURCE_DIR} + RESULT_VARIABLE numa_configure_result + OUTPUT_VARIABLE numa_configure_output + ERROR_VARIABLE numa_configure_error + ) + if(NOT numa_configure_result EQUAL 0) + message(FATAL_ERROR "Failed to configure numactl. \n" + "Result: ${numa_configure_result}\n" + "STDOUT: ${numa_configure_output}\n" + "STDERR: ${numa_configure_error}\n") + endif() + + message(STATUS "Building and installing numactl...") + execute_process( + COMMAND make install -j8 + WORKING_DIRECTORY ${numactl_SOURCE_DIR} + RESULT_VARIABLE numa_install_result + OUTPUT_VARIABLE numa_install_output + ERROR_VARIABLE numa_install_error + ) + if(NOT numa_install_result EQUAL 0) + message(FATAL_ERROR "Failed to build and install numactl. \n" + "Result: ${numa_install_result}\n" + "STDOUT: ${numa_install_output}\n" + "STDERR: ${numa_install_error}\n") + endif() +else() + message(STATUS "Found already built libnuma. 
Skipping build.") +endif() + add_subdirectory(retrieval) diff --git a/ucm/sparse/esa/esa.py b/ucm/sparse/esa/esa.py index 7f70dd70..5ccaec4a 100644 --- a/ucm/sparse/esa/esa.py +++ b/ucm/sparse/esa/esa.py @@ -1,6 +1,7 @@ import hashlib import math import pickle +from collections import defaultdict from dataclasses import dataclass from functools import cache from typing import Dict, List, Optional, Union @@ -23,6 +24,7 @@ ) from ucm.sparse.esa.retrieval import retrieval_backend from ucm.sparse.esa.retrieval.retrieval_worker import RetrievalWorker +from ucm.sparse.kvstar.utils import get_bind_cpus_for_rank from ucm.store.ucmstore import Task, UcmKVStoreBase ReqType = Union[str, int] @@ -203,9 +205,9 @@ def __init__( self.rank = rank self.tp_size = tp_size self.tasks: Dict[str, Task] = {} - self.esa_cfg = vllm_config.kv_transfer_config.kv_connector_extra_config[ - "ucm_sparse_config" - ]["ESA"] + self.esa_cfg = vllm_config.kv_transfer_config.kv_connector_extra_config.get( + "ucm_sparse_config", {} + ).get("ESA", None) self.indexes: Optional[NDArray[np.int64]] = None self.block_hashes = None self.pre_topk_block_hashes: Dict[int, str] = {} @@ -502,10 +504,25 @@ def __init__(self, vllm_config: VllmConfig, role: UcmSparseRole): ReprePool(num_slots) for _ in range(self.total_num_hidden_layers) ] + self.local_tp_rank = vllm_config.parallel_config.rank + self.total_tp_size = vllm_config.parallel_config.tensor_parallel_size + ratio = 0.75 + + bind_info_list, alloc_numa_ids = get_bind_cpus_for_rank( + self.total_tp_size, self.local_tp_rank, ratio=ratio + ) + + bind_info_dict = defaultdict(list) + for item in bind_info_list: + bind_info_dict[item[1]].append(item[0]) + bind_info_dict = dict(bind_info_dict) + self.retrieval_workers: List[RetrievalWorker] = [] for i in range(self.total_num_hidden_layers): backend_src = data[i] - backend = retrieval_backend.RetrievalWorkerBackend(backend_src) + backend = retrieval_backend.RetrievalWorkerBackend( + backend_src, bind_info_dict + ) self.retrieval_workers.append(RetrievalWorker(backend)) self.preempt_req_output_tokens: Dict[ReqType, int] = {} diff --git a/ucm/sparse/esa/retrieval/CMakeLists.txt b/ucm/sparse/esa/retrieval/CMakeLists.txt index 3ef08d07..ad05d7bc 100644 --- a/ucm/sparse/esa/retrieval/CMakeLists.txt +++ b/ucm/sparse/esa/retrieval/CMakeLists.txt @@ -1,2 +1,17 @@ +# 添加编译目标 pybind11_add_module(retrieval_backend cpy/retrieval_backend.cpp) + +# 设置输出库的目录 set_target_properties(retrieval_backend PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) + +# 设置头文件目录,以确保 numaf.h 能找到 +target_include_directories(retrieval_backend PUBLIC + ${NUMA_INSTALL_DIR}/include + ${Torch_INCLUDE_DIRS} +) + +# 链接所需的库 +target_link_libraries(retrieval_backend PUBLIC + ${NUMA_INSTALL_DIR}/lib/libnuma.so + ${Torch_LIBRARIES} +) diff --git a/ucm/sparse/esa/retrieval/cpy/retrieval_backend.cpp b/ucm/sparse/esa/retrieval/cpy/retrieval_backend.cpp index 6a81af73..b0f3ea07 100644 --- a/ucm/sparse/esa/retrieval/cpy/retrieval_backend.cpp +++ b/ucm/sparse/esa/retrieval/cpy/retrieval_backend.cpp @@ -11,12 +11,15 @@ #include #include #include +#include +#include namespace py = pybind11; class RetrievalWorkerBackend { public: - RetrievalWorkerBackend(py::array_t data) + RetrievalWorkerBackend(py::array_t data, + py::dict cpu_idx_tbl) : data_array_(data), stop_workers_(false), next_req_id_(0) { py::buffer_info info = data_array_.request(); @@ -25,9 +28,35 @@ class RetrievalWorkerBackend { data_ = static_cast(info.ptr); // Start worker threads - int n_workers = 
std::thread::hardware_concurrency(); - for (int i = 0; i < n_workers; ++i) { - worker_threads_.emplace_back(&RetrievalWorkerBackend::worker_loop, this); + for (auto cpu_idx : cpu_idx_tbl) { + int numaId = cpu_idx.first.cast(); + py::list core_ids = cpu_idx.second.cast(); + + for (size_t i = 0; i < core_ids.size(); ++i) { + int core_id = core_ids[i].cast(); + worker_threads_.emplace_back(&RetrievalWorkerBackend::worker_loop, this); + + // 核心绑定代码 + cpu_set_t cpuset; + CPU_ZERO(&cpuset); + CPU_SET(core_id, &cpuset); // 绑定每个线程到指定的核心 + + pthread_t thread = worker_threads_.back().native_handle(); + + // 设置 CPU 亲和性 + int rc = pthread_setaffinity_np(thread, sizeof(cpu_set_t), &cpuset); + if (rc != 0) { + std::cerr << "Error binding thread " << i << " to CPU core " << core_id << std::endl; + } + + // 设置内存亲和性 + unsigned long nodeMask = 1UL << numaId; + rc = set_mempolicy(MPOL_BIND, &nodeMask, sizeof(nodeMask) * 8); + if (rc != 0) { + std::cerr << "Error binding memory to NUMA node " << numaId << std::endl; + } + } + } } @@ -217,7 +246,7 @@ class RetrievalWorkerBackend { PYBIND11_MODULE(retrieval_backend, m) { py::class_(m, "RetrievalWorkerBackend") - .def(py::init>()) + .def(py::init, py::dict>()) .def("submit", &RetrievalWorkerBackend::submit) .def("poll", &RetrievalWorkerBackend::poll) .def("get_result", &RetrievalWorkerBackend::get_result) diff --git a/ucm/sparse/kvcomp/CMakeLists.txt b/ucm/sparse/kvcomp/CMakeLists.txt index e69de29b..be24b30b 100644 --- a/ucm/sparse/kvcomp/CMakeLists.txt +++ b/ucm/sparse/kvcomp/CMakeLists.txt @@ -0,0 +1,42 @@ +set(NUMA_INSTALL_DIR ${CMAKE_CURRENT_BINARY_DIR}/numa_install) +FetchContent_Declare( + numactl + URL https://github.com/numactl/numactl/releases/download/v2.0.16/numactl-2.0.16.tar.gz + TLS_VERIFY OFF +) +FetchContent_MakeAvailable(numactl) +if(NOT EXISTS "${NUMA_INSTALL_DIR}/lib/libnuma.so") + message(STATUS "Configuring numactl...") + execute_process( + COMMAND ./configure --prefix=${NUMA_INSTALL_DIR} + WORKING_DIRECTORY ${numactl_SOURCE_DIR} + RESULT_VARIABLE numa_configure_result + OUTPUT_VARIABLE numa_configure_output + ERROR_VARIABLE numa_configure_error + ) + if(NOT numa_configure_result EQUAL 0) + message(FATAL_ERROR "Failed to configure numactl. \n" + "Result: ${numa_configure_result}\n" + "STDOUT: ${numa_configure_output}\n" + "STDERR: ${numa_configure_error}\n") + endif() + + message(STATUS "Building and installing numactl...") + execute_process( + COMMAND make install -j8 + WORKING_DIRECTORY ${numactl_SOURCE_DIR} + RESULT_VARIABLE numa_install_result + OUTPUT_VARIABLE numa_install_output + ERROR_VARIABLE numa_install_error + ) + if(NOT numa_install_result EQUAL 0) + message(FATAL_ERROR "Failed to build and install numactl. \n" + "Result: ${numa_install_result}\n" + "STDOUT: ${numa_install_output}\n" + "STDERR: ${numa_install_error}\n") + endif() +else() + message(STATUS "Found already built libnuma. 
Skipping build.") +endif() + +add_subdirectory(hash_retrieval) diff --git a/ucm/sparse/kvcomp/configs/kvcomp_deepseek_v2_lite_config.json b/ucm/sparse/kvcomp/configs/kvcomp_deepseek_v2_lite_config.json new file mode 100644 index 00000000..0ada4477 --- /dev/null +++ b/ucm/sparse/kvcomp/configs/kvcomp_deepseek_v2_lite_config.json @@ -0,0 +1,81 @@ +{ + "model_name": "DeepSeek/DeepSeek-V2-Lite-Chat", + "is_mla": true, + "hash_weight_type": "random", + "num_hidden_layers": 27, + "seq_len_threshhold": 2048, + "chunk_size": 128, + "chunk_repre_method": "max", + "head_dim": 576, + "hash_bits": 128, + "top_k_ratio_per_layer": [ + 1, + 1, + 0.3, + 0.3, + 0.3, + 0.3, + 0.3, + 0.3, + 0.3, + 0.3, + 0.3, + 0.3, + 0.3, + 0.3, + 0.3, + 0.3, + 0.3, + 0.3, + 0.3, + 0.3, + 0.3, + 0.3, + 0.3, + 0.3, + 1, + 1, + 1 + ], + "top_k_index_reuse": [ + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1, + -1 + ], + "must_select_blocks": [ + 0, + -2, + -1 + ], + "hash_weight": null, + "kv_lora_rank": 512, + "qk_rope_head_dim": 64, + "hash_bits_kv_lora": 512, + "hash_bits_qk_rope": 64, + "hash_weight_kv_lora": null, + "hash_weight_qk_rope": null +} \ No newline at end of file diff --git a/ucm/sparse/kvcomp/hash_encoder.py b/ucm/sparse/kvcomp/hash_encoder.py index 79870454..db76079d 100644 --- a/ucm/sparse/kvcomp/hash_encoder.py +++ b/ucm/sparse/kvcomp/hash_encoder.py @@ -62,13 +62,25 @@ def __init__( logger.warning("automatically using float16 for hash_weights now") self.dtype = torch.float16 - self.hash_weights = torch.normal( + if self.device.type == "cuda" and dtype == torch.bfloat16: + logger.warning("geqrf_cuda not implemented for BFloat16") + logger.warning("automatically using float32 for hash_weights now") + self.dtype = torch.float32 + + # Step 1: 随机高斯矩阵 + random_weights = torch.normal( mean=0, std=2, size=(self.input_dim, self.hash_bits), dtype=self.dtype, device=self.device, ) + # Step 2: QR分解 + Q, R = torch.linalg.qr(random_weights) + + # Step 3: 调整符号,保证Haar 分布 + d = torch.sign(torch.diag(R)) + self.hash_weights = Q * d if self.device.type == "cuda" or self.device.type == "cpu": self._init_bit_masks() diff --git a/ucm/sparse/kvcomp/hash_retrieval/CMakeLists.txt b/ucm/sparse/kvcomp/hash_retrieval/CMakeLists.txt new file mode 100644 index 00000000..5f2e8f7f --- /dev/null +++ b/ucm/sparse/kvcomp/hash_retrieval/CMakeLists.txt @@ -0,0 +1,17 @@ +# 添加编译目标 +pybind11_add_module(hash_retrieval_backend cpy/hash_retrieval_backend.cpp) + +# 设置输出库的目录 +set_target_properties(hash_retrieval_backend PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) + +# 设置头文件目录,以确保 numaf.h 能找到 +target_include_directories(hash_retrieval_backend PUBLIC + ${NUMA_INSTALL_DIR}/include + ${Torch_INCLUDE_DIRS} +) + +# 链接所需的库 +target_link_libraries(hash_retrieval_backend PUBLIC + ${NUMA_INSTALL_DIR}/lib/libnuma.so + ${Torch_LIBRARIES} +) \ No newline at end of file diff --git a/ucm/sparse/kvcomp/hash_retrieval/__init__.py b/ucm/sparse/kvcomp/hash_retrieval/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ucm/sparse/kvcomp/hash_retrieval/cpy/hash_retrieval_backend.cpp b/ucm/sparse/kvcomp/hash_retrieval/cpy/hash_retrieval_backend.cpp new file mode 100644 index 00000000..18bf393a --- /dev/null +++ b/ucm/sparse/kvcomp/hash_retrieval/cpy/hash_retrieval_backend.cpp @@ -0,0 +1,364 @@ +// hash_retrieval_backend.cpp + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include 
+#include +#include +#include +#include +#include // 用于UINT16_MAX +#include +#include +#ifdef __ARM_NEON +#include // ARM NEON SIMD 指令集头文件 +#elif defined(__x86_64__) || defined(_M_X64) +#include // x86_64 SSE SIMD 指令集头文件 +#endif + + +#define VEC_SIZE 16 + +namespace py = pybind11; + +class HashRetrievalWorkerBackend { +public: + HashRetrievalWorkerBackend(py::array_t data, + py::dict cpu_idx_tbl) + : data_array_(data), stop_workers_(false), next_req_id_(0) + { + py::buffer_info info = data_array_.request(); + num_blocks_ = info.shape[0]; + block_size_ = info.shape[1]; + dim_ = info.shape[2]; + vec_per_dim_ = dim_ / VEC_SIZE; // data_每个值类型uint8_t,组成8*16_t进行simd加速 + data_ = static_cast(info.ptr); + + // Start worker threads + for (auto cpu_idx : cpu_idx_tbl) { + int numaId = cpu_idx.first.cast(); + py::list core_ids = cpu_idx.second.cast(); + + for (size_t i = 0; i < core_ids.size(); ++i) { + int core_id = core_ids[i].cast(); + worker_threads_.emplace_back(&HashRetrievalWorkerBackend::worker_loop, this); + + // 核心绑定代码 + cpu_set_t cpuset; + CPU_ZERO(&cpuset); + CPU_SET(core_id, &cpuset); // 绑定每个线程到指定的核心 + + pthread_t thread = worker_threads_.back().native_handle(); + + // 设置 CPU 亲和性 + int rc = pthread_setaffinity_np(thread, sizeof(cpu_set_t), &cpuset); + if (rc != 0) { + std::cerr << "Error binding thread " << i << " to CPU core " << core_id << std::endl; + } + + // 设置内存亲和性 + unsigned long nodeMask = 1UL << numaId; + rc = set_mempolicy(MPOL_BIND, &nodeMask, sizeof(nodeMask) * 8); + if (rc != 0) { + std::cerr << "Error binding memory to NUMA node " << numaId << std::endl; + } + } + + } + } + + ~HashRetrievalWorkerBackend() { + { + std::lock_guard lock(mutex_); + stop_workers_ = true; + cond_.notify_all(); + } + for (auto& t: worker_threads_) t.join(); + } + + int submit(py::array_t query, int topk, py::array_t indexes) { + py::buffer_info qinfo = query.request(); + py::buffer_info iinfo = indexes.request(); + if (qinfo.shape[1] != dim_) + throw std::runtime_error("Query dim mismatch"); + if ((size_t)iinfo.shape[0] != (size_t)qinfo.shape[0]) + throw std::runtime_error("Query and indexes batch mismatch"); + + int req_id = next_req_id_.fetch_add(1); + + auto q = std::vector((uint8_t*)qinfo.ptr, (uint8_t*)qinfo.ptr + qinfo.shape[0] * dim_); + + // Parse indexes to vector> + size_t n_requests = iinfo.shape[0], max_index_number = iinfo.shape[1]; + const int* idx_ptr = static_cast(iinfo.ptr); + std::vector> idxvec(n_requests); + for (size_t i = 0; i < n_requests; ++i) { + for (size_t j = 0; j < max_index_number; ++j) { + int index = idx_ptr[i * max_index_number + j]; + if (index != -1) idxvec[i].push_back(index); + } + } + + auto status = std::make_shared(); + { + std::lock_guard lock(mutex_); + requests_.emplace(Request{req_id, std::move(q), n_requests, topk, std::move(idxvec)}); + request_status_[req_id] = status; + } + cond_.notify_one(); + return req_id; + } + + bool poll(int req_id) { + std::lock_guard lock(mutex_); + return results_.find(req_id) != results_.end(); + } + + void wait(int req_id) { + std::shared_ptr s; + { + std::lock_guard lock(mutex_); + auto it = request_status_.find(req_id); + if (it == request_status_.end()) throw std::runtime_error("Bad req_id"); + s = it->second; + } + std::unique_lock lk2(s->m); + s->cv.wait(lk2, [&] { return s->done; }); + } + + py::dict get_result(int req_id) { + std::lock_guard lock(mutex_); + auto it = results_.find(req_id); + if (it == results_.end()) throw std::runtime_error("Result not ready"); + + size_t batch_size = it->second.indices.size(); + 
size_t topk = batch_size > 0 ? it->second.indices[0].size() : 0; + py::array_t indices({batch_size, topk}); + py::array_t scores({batch_size, topk}); + + auto indices_ptr = static_cast(indices.request().ptr); + auto scores_ptr = static_cast(scores.request().ptr); + + for (size_t i = 0; i < batch_size; ++i) { + memcpy(indices_ptr + i * topk, it->second.indices[i].data(), topk * sizeof(int)); + memcpy(scores_ptr + i * topk, it->second.scores[i].data(), topk * sizeof(int)); + } + py::dict result; + result["indices"] = indices; + result["scores"] = scores; + results_.erase(it); + return result; + } + +private: + struct Request { + int req_id; + std::vector query; // Flattened [batch, dim] + size_t batch; + int topk; + std::vector> indexes; // Per-request index subset + }; + struct Result { + std::vector> indices; + std::vector> scores; + }; + + struct RequestStatus { + std::mutex m; + std::condition_variable cv; + bool done = false; + }; + +#ifdef __ARM_NEON + static inline uint16_t vaddvq_u8_compat(uint8x16_t v) { + #if defined(__aarch64__) || defined(_M_ARM64) + return vaddvq_u8(v); + #else + uint16x8_t s16 = vpaddlq_u8(v); + uint32x4_t s32 = vpaddlq_u16(s16); + uint64x2_t s64 = vpaddlq_u32(s32); + return (uint16_t)(vgetq_lane_u64(s64, 0) + vgetq_lane_u64(s64, 1)); + #endif + } + + void print_uint8x16(uint8x16_t vec) { + uint8_t array[16]; + vst1q_u8(array, vec); + for (int i = 0; i < 16; ++i) { + std::cout << static_cast(array[i]) << " "; + } + std::cout << std::endl; + } + +#elif defined(__x86_64__) || defined(_M_X64) + // 采用 Brian Kernighan's 算法计算 64 位数的 Hamming Weight + unsigned int popcnt64(uint64_t x) { + x -= (x >> 1) & 0x5555555555555555; // 将相邻的两位合并 + x = (x & 0x3333333333333333) + ((x >> 2) & 0x3333333333333333); // 合并四位 + x = (x + (x >> 4)) & 0x0F0F0F0F0F0F0F0F; // 合并八位 + x = x + (x >> 8); // 合并十六位 + x = x + (x >> 16); // 合并三十二位 + x = x + (x >> 32); // 合并六十四位 + return x & 0x7F; // 返回最后的1的个数,0x7F表示最多返回 7 位 + } + + // 计算 128 位向量中 1 的个数 + int popcount_128(__m128i xor_result) { + // 将 128 位数据拆成两个 64 位整数 + uint64_t* result = (uint64_t*)&xor_result; + + // 分别计算每个 64 位的 Hamming 权重并返回结果之和 + return popcnt64(result[0]) + popcnt64(result[1]); + } +#endif + + void worker_loop() { + while (true) { + Request req; + { + std::unique_lock lock(mutex_); + cond_.wait(lock, [&]{ return stop_workers_ || !requests_.empty(); }); + if (stop_workers_ && requests_.empty()) return; + req = std::move(requests_.front()); + requests_.pop(); + } + + Result res; + res.indices.resize(req.batch); + res.scores.resize(req.batch); + + // #pragma omp parallel for schedule(dynamic) + for (size_t b = 0; b < req.batch; ++b) { + const uint8_t* q_ptr = req.query.data() + b * dim_; + const auto& allowed = req.indexes[b]; + std::vector> heap; + heap.reserve(allowed.size()); + + // 1.预加载 query 向量 + #ifdef __ARM_NEON + uint8x16_t q_vecs[vec_per_dim_]; // 存储 query 向量 + for (size_t v = 0; v < vec_per_dim_; ++v) { + q_vecs[v] = vld1q_u8(q_ptr + v * VEC_SIZE); + } + #elif defined(__x86_64__) || defined(_M_X64) + __m128i q_vecs[vec_per_dim_]; // 存储 query 向量 + for (size_t v = 0; v < vec_per_dim_; ++v) { + q_vecs[v] = _mm_loadu_si128(reinterpret_cast(q_ptr + v * VEC_SIZE)); + } + #endif + + // 2.遍历允许的索引 + for (auto idx : allowed) { + const uint8_t* base_idx_ptr = data_ + idx * block_size_ * dim_; + + int score = UINT16_MAX; // 初始化为最大值 + + // 3.内层向量化计算 + // #pragma omp parallel for + for (size_t t_idx = 0; t_idx < block_size_; ++t_idx) { + int sum = 0; + const uint8_t* k_base = base_idx_ptr + t_idx * dim_; + + // 计算每个向量的相似度 + for 
(size_t v = 0; v < vec_per_dim_; ++v) { + #ifdef __ARM_NEON + uint8x16_t k = vld1q_u8(k_base + v * VEC_SIZE); + sum += vaddvq_u8_compat(vcntq_u8(veorq_u8(q_vecs[v], k))); + #elif defined(__x86_64__) || defined(_M_X64) + __m128i k = _mm_loadu_si128(reinterpret_cast(k_base + v * VEC_SIZE)); + __m128i xor_result = _mm_xor_si128(q_vecs[v], k); // 16 * 8 + int popcount_result = popcount_128(xor_result); // 计算128位 xor_result 中所有位为 1 的个数 + sum += popcount_result; // 获取每个字节的累计值 + #endif + } + + // 处理不足16字节的部分 + ssize_t tail_dim = dim_ % VEC_SIZE; + if (tail_dim != 0) { + uint8_t q_tmp[16] = { 0 }; // 初始化填充为0 + uint8_t k_tmp[16] = { 0 }; + memcpy(q_tmp, q_ptr, dim_); + memcpy(k_tmp, k_base, dim_); + + #ifdef __ARM_NEON + uint8x16_t q = vld1q_u8(q_tmp); + uint8x16_t k = vld1q_u8(k_tmp); + sum += vaddvq_u8_compat(vcntq_u8(veorq_u8(q, k))); + #elif defined(__x86_64__) || defined(_M_X64) + __m128i q = _mm_loadu_si128(reinterpret_cast(q_tmp)); + __m128i k = _mm_loadu_si128(reinterpret_cast(k_tmp)); + __m128i xor_result = _mm_xor_si128(q, k); + int popcount_result = popcount_128(xor_result); // 计算128位 xor_result 中所有位为 1 的个数 + sum += popcount_result; // 获取每个字节的累计值 + #endif + } + + // 如果得分为0,则跳出循环 + if (sum < score) { + score = sum; + if (score == 0) { + break; + } + } + } + + // 将结果加入堆中 + heap.emplace_back(score, idx); + } + + // 获取当前TopK + int curr_topk = std::min((int)heap.size(), req.topk); + + // 对堆进行部分排序,获取TopK + std::partial_sort(heap.begin(), heap.begin() + curr_topk, heap.end(), + [](const auto& a, const auto& b) { return a.first < b.first; }); + + // 保存TopK结果 + for (int k = 0; k < curr_topk; ++k) { + res.scores[b].push_back(heap[k].first); + res.indices[b].push_back(heap[k].second); + } + } + + { + std::lock_guard lock(mutex_); + results_[req.req_id] = std::move(res); + auto s = request_status_[req.req_id]; + { + std::lock_guard lk2(s->m); + s->done = true; + } + s->cv.notify_all(); + } + } + } + + py::array_t data_array_; + const uint8_t* data_ = nullptr; + ssize_t dim_; + size_t num_blocks_, block_size_, vec_per_dim_; + std::queue requests_; + std::unordered_map results_; + std::vector worker_threads_; + std::mutex mutex_; + std::condition_variable cond_; + std::unordered_map> request_status_; + bool stop_workers_; + std::atomic next_req_id_; +}; + +PYBIND11_MODULE(hash_retrieval_backend, m) { + py::class_(m, "HashRetrievalWorkerBackend") + .def(py::init, py::dict>()) + .def("submit", &HashRetrievalWorkerBackend::submit) + .def("poll", &HashRetrievalWorkerBackend::poll) + .def("get_result", &HashRetrievalWorkerBackend::get_result) + .def("wait", &HashRetrievalWorkerBackend::wait); +} \ No newline at end of file diff --git a/ucm/sparse/kvcomp/hash_retrieval/hash_retrieval_worker.py b/ucm/sparse/kvcomp/hash_retrieval/hash_retrieval_worker.py new file mode 100644 index 00000000..7a77b05a --- /dev/null +++ b/ucm/sparse/kvcomp/hash_retrieval/hash_retrieval_worker.py @@ -0,0 +1,93 @@ +import time + +import numpy as np +import torch + +from ucm.sparse.kvcomp.hash_encoder import HashEncoder +from ucm.sparse.kvcomp.hash_retrieval import hash_retrieval_backend + + +class HashRetrievalWorker: + # handle torch -> numpy && float16/bfloat16 -> float32. 
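+    # Thin Python-side wrapper around the C++ HashRetrievalWorkerBackend:
+    # handle_input() casts hash-code tensors to uint8 and moves them to CPU,
+    # submit() enqueues an asynchronous top-k Hamming search and returns a
+    # request id, and poll()/wait()/get_result() mirror the backend API
+    # (get_result() returns "indices" and "scores" arrays of shape [batch, topk]).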
+ def __init__(self, cpp_worker): + self.cpp_worker = cpp_worker + + @classmethod + def handle_input(cls, input): + if input.dtype != torch.uint8: + input = input.to(torch.uint8) + input = input.to("cpu", non_blocking=True) + return input + + def submit(self, query, topk, indexes): + q = self.handle_input(query) + req_id = self.cpp_worker.submit(q, topk, indexes) + return req_id + + def poll(self, req_id): + return self.cpp_worker.poll(req_id) # Returns True if ready + + def get_result(self, req_id): + return self.cpp_worker.get_result(req_id) + + def wait(self, req_id): + return self.cpp_worker.wait(req_id) + + +if __name__ == "__main__": + ################# data + batch_size = 2 + dim = 1024 + kv_cache_blocks = 25600 + data = torch.rand(kv_cache_blocks, dim).to(torch.float32) + print("data created", data.shape) + + backend = hash_retrieval_backend.HashRetrievalWorkerBackend(data) + worker = HashRetrievalWorker(backend) + topk = 3000 + search_blocks_range = 8000 + tpot = 30 / 1000 + + indexes = np.arange(batch_size * search_blocks_range).reshape( + batch_size, search_blocks_range + ) + + query = torch.rand(batch_size, dim).to(torch.float32) + + #################### cpp async version + req_id = worker.submit(query, topk=topk, indexes=indexes) + + #################### LLM decode begin + time.sleep(tpot * 3) + #################### LLM decode done + + # Poll and get result (in a real program, you'd likely use asyncio or threading) + begin = time.time() + worker.wait(req_id) + result = worker.get_result(req_id) + print("cpp spent:", time.time() - begin) + + ################### numpy version + begin = time.time() + data_indexed = ( + data[indexes.flatten()].reshape(indexes.shape[0], indexes.shape[1], dim).numpy() + ) + query = HashRetrievalWorker.handle_input(query) + scores = np.matmul(query[:, None, :], data_indexed.transpose((0, 2, 1))) + scores = scores[:, 0, :] + topk_elements = np.partition(scores, -topk, -1)[:, -topk:] + topk_indices = np.argpartition(scores, -topk, -1)[:, -topk:] + topk_indices = indexes[np.arange(indexes.shape[0])[:, None], topk_indices] + print("numpy spent: ", time.time() - begin) + + ## compare + cpp_elements = np.sort(result["scores"], 1) + cpp_indices = np.sort(result["indices"], 1) + + np_elements = np.sort(topk_elements, 1) + np_indices = np.sort(topk_indices, 1) + + diff_elements = np.abs(np_elements - cpp_elements) + diff_indices = np.abs(np_indices - cpp_indices) + + print(f"diff topk: {diff_indices.max()}") diff --git a/ucm/sparse/kvcomp/kvcomp.py b/ucm/sparse/kvcomp/kvcomp.py index c1884b2e..c1713300 100644 --- a/ucm/sparse/kvcomp/kvcomp.py +++ b/ucm/sparse/kvcomp/kvcomp.py @@ -1,587 +1,292 @@ -""" -The MIT License - -Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in -all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -THE SOFTWARE. -""" - import math -import time +from collections import defaultdict from dataclasses import dataclass -from functools import wraps -from typing import Dict, List, Union +from typing import Dict, List, Optional, Union +import numpy as np import torch from vllm.config import VllmConfig +from vllm.distributed.kv_transfer import get_kv_transfer_group from vllm.forward_context import ForwardContext -from vllm.sequence import SequenceStage -from vllm.v1.request import Request +from vllm.v1.request import Request, RequestStatus from ucm.logger import init_logger -from ucm.sandbox.sparse.kvcomp.hash_encoder import HashEncoder -from ucm.sandbox.sparse.kvcomp.kvcomp_config import KvCompConfig -from ucm.sparse.state import get_ucm_sparse - -logger = init_logger(__name__) - from ucm.sparse.base import ( INVALID_SLOT, UcmSparseBase, - UcmSparseMetadata, UcmSparseRole, ) +from ucm.sparse.esa.esa import ( + ESA, + ESASparseMetaData, + ReprePool, + ReqStatePerLayer, + get_sparse_range, +) +from ucm.sparse.kvcomp.hash_encoder import HashEncoder +from ucm.sparse.kvcomp.hash_retrieval import hash_retrieval_backend +from ucm.sparse.kvcomp.hash_retrieval.hash_retrieval_worker import HashRetrievalWorker +from ucm.sparse.kvcomp.kvcomp_config import KvCompConfig +from ucm.sparse.kvstar.utils import get_bind_cpus_for_rank from ucm.sparse.state import get_ucm_sparse -from ucm.store.factory import UcmConnectorFactory from ucm.store.ucmstore import Task, UcmKVStoreBase - -def stat(func): - @wraps(func) - def wrapper(*args, **kwargs): - wrapper.call_count += 1 - start = time.perf_counter_ns() - result = func(*args, **kwargs) - end = time.perf_counter_ns() - cost = end - start - wrapper.time_costs.append(cost) - return result - - wrapper.call_count = 0 - wrapper.time_costs = [] - return wrapper - +logger = init_logger(__name__) ReqType = Union[str, int] -HashType = Union[str, int] - -# TODO: add ESA specific config in kv_transfer_config -> extra_config -INIT_WINDOW_SZ = 1 -LOCAL_WINDOW_SZ = 2 -SPARSE_RATIO = 0.3 -RETRIEVAL_STRIDE = 4 - - -@dataclass -class ReqMeta: - request_id: ReqType - index_in_batch: int - num_prompt_tokens: int - num_output_tokens: int - num_scheduled_tokens: int - num_computed_tokens: int - num_sparsed_tokens: int - vllm_block_ids: list[int] - - @property - def step(self) -> int: - return self.num_output_tokens - - @property - def stage(self) -> SequenceStage: - return ( - SequenceStage.DECODE - if self.num_output_tokens > 0 - else SequenceStage.PREFILL - ) - @property - def is_last_chunk(self) -> bool: - return ( - self.num_computed_tokens + self.num_scheduled_tokens - >= self.num_prompt_tokens - ) +data = None -@dataclass -class KvCompSparseMetaData(UcmSparseMetadata): - requests: list[ReqMeta] - finished_req_ids: List[ReqType] - - def __init__(self): - self.requests = [] - self.finished_req_ids = [] - - def add_request( - self, - request_id: ReqType, - index_in_batch: int, - num_prompt_tokens: int, - num_output_tokens: int, - num_scheduled_tokens: int, - num_computed_tokens: int, - num_sparsed_tokens: int, - vllm_block_ids: list[int], - ) -> None: - meta = ReqMeta( - request_id=request_id, - index_in_batch=index_in_batch, - num_prompt_tokens=num_prompt_tokens, - num_output_tokens=num_output_tokens, - 
num_scheduled_tokens=num_scheduled_tokens, - num_computed_tokens=num_computed_tokens, - num_sparsed_tokens=num_sparsed_tokens, - vllm_block_ids=vllm_block_ids, - ) - self.requests.append(meta) - - -def get_offset(block_shape, rank, tp_size, precision, layer_id, is_v, is_mla) -> int: - block_size, num_key_heads_per_tp, head_size = block_shape - k_min_data_block_size = block_size * num_key_heads_per_tp * head_size * precision - v_min_data_block_size = k_min_data_block_size if not is_mla else 0 - layer_size = (k_min_data_block_size + v_min_data_block_size) * tp_size - if is_mla: - k_offset = layer_size * layer_id - else: - k_offset = layer_size * layer_id + layer_size // tp_size * rank - v_offset = k_offset + k_min_data_block_size - return v_offset if is_v else k_offset - - -class ReqStatePerLayer: +class ReqStatePerLayerKvComp(ReqStatePerLayer): # handle single request per layer def __init__( self, - req_meta: ReqMeta, layer_name: str, rank: int, tp_size: int, store_instance: UcmKVStoreBase, + vllm_config: VllmConfig, + retrieval_worker: Optional[HashRetrievalWorker] = None, + repre_pool: Optional[ReprePool] = None, ): - self.layer_name = layer_name - self.layer_id = int(layer_name.split(".")[2]) - self.block_repre: torch.Tensor = ( - None ## shape: blks, num_key_heads_per_tp, head_size + super().__init__( + layer_name, + rank, + tp_size, + store_instance, + vllm_config, + retrieval_worker, + repre_pool, ) - self.init_window: tuple[torch.Tensor, torch.Tensor] = None - self.local_window: tuple[torch.Tensor, torch.Tensor] = None - self.store_instance = store_instance - self.req_meta = req_meta - self.block_size = None - self.k_cache = None - self.v_cache = None - self.rank = rank - self.tp_size = tp_size - self.tasks: Dict[str, Task] = {} - self.init_window_sz = INIT_WINDOW_SZ - self.local_window_sz = LOCAL_WINDOW_SZ - - @classmethod - def req_state_hash(cls, req_id, layer_name): - return hash((req_id, layer_name)) - - @classmethod - def block_hash(cls, request_id, block_id): - return f"req_{request_id}_blk_{block_id}" - - @classmethod - def task_hash(cls, block_ids, store_type, tensor_type): - return hash((tuple(block_ids), store_type, tensor_type)) - - def update_meta(self, req_meta: ReqMeta, forward_context: ForwardContext): - self.req_meta = req_meta - - def retrieval(self, query: torch.Tensor, top_k: int): - if top_k >= self.block_repre.shape[0]: - n_blocks = self.block_repre.shape[0] - block_ids = list( - range(self.init_window_sz, n_blocks - self.local_window_sz + 1) - ) - block_hashes = [ - f"{self.block_hash(self.req_meta.request_id, id_)}" for id_ in block_ids - ] - return block_hashes + + self.esa_cfg = vllm_config.kv_transfer_config.kv_connector_extra_config[ + "ucm_sparse_config" + ]["KvComp"] + # `retrieval_worker` 类型是 HashRetrievalWorker + self.retrieval_worker = retrieval_worker + + def extract_block_repre(self, vllm_block_ids): + ucm_sparse = get_ucm_sparse() + hash_encoder = ucm_sparse.hash_encoder + hashk_cache = hash_encoder.compute_hash(self.k_cache[vllm_block_ids]) + if self.is_mla: + hashk_cache = hashk_cache.unsqueeze(-2) + return hashk_cache + + def start_retrieval(self, batch_query, forward_context): + query_start_loc = self.req_meta.query_start_loc + query_len = self.req_meta.num_scheduled_tokens + query = batch_query[query_start_loc : query_start_loc + query_len] ntokens, num_q_heads, _ = query.shape if num_q_heads > self.num_key_heads: query = query.view(ntokens, self.num_key_heads, -1, self.head_size) query = query.mean(2) elif num_q_heads < self.num_key_heads: 
query = torch.repeat_interleave(query, self.num_key_heads // num_q_heads, 1) - - retrieval_start = self.init_window_sz - retrieval_end = self.block_repre.shape[0] - self.local_window_sz + 1 - block_repre_ = self.block_repre[retrieval_start:retrieval_end] - - if block_repre_.shape[0] == 0: - scores = torch.empty( - (block_repre_.shape[0]), dtype=query.dtype, device=query.device - ) - else: - ucm_sparse = get_ucm_sparse() - hash_encoder = ucm_sparse.hash_encoder - # query.shape [ntokens/BS, num_heads, head_size] - - # hash_query.shape [ntokens/BS, num_heads, hash_bits//8] - hash_query = hash_encoder.compute_hash(query) - # unpack_hash_query.shape [ntokens/BS, num_heads, hash_bits//8, 8] - unpack_hash_query = hash_encoder._unpack_hash(hash_query) - - # block_repre_.shape [n_blocks, block_size, num_kv_heads, head_size] - # unpack_hash_key_cache.shape [n_blocks, block_size, num_kv_heads, hash_bits//8, 8] - unpack_hash_key_cache = hash_encoder._unpack_hash(block_repre_) - - scores = torch.einsum( - "tid,njsd->tijsn", unpack_hash_query, unpack_hash_key_cache - ) - dims = tuple(range(scores.ndim - 1)) - - # [ntokens/BS, n_blocks] - scores = scores.sum(dim=dims) - - topk_ret = torch.topk(scores, top_k) - topk_index = topk_ret.indices - topk_index = ( - topk_index.sort().values - ) # TODO: remove this, don't need to sort in decode - block_ids = [id.item() + self.init_window_sz for id in topk_index] - block_hashes = [ - f"{self.block_hash(self.req_meta.request_id, id_)}" for id_ in block_ids - ] - return block_hashes - - def construct_init_and_local_window(self): - vllm_block_ids = self.req_meta.vllm_block_ids - # TODO: make sure we don't need to clone() - self.init_window = ( - self.k_cache[vllm_block_ids[: self.init_window_sz]], - self.v_cache[vllm_block_ids[: self.init_window_sz]], - ) - local_window_sz = min( - self.local_window_sz, len(vllm_block_ids[self.init_window_sz :]) - ) - if local_window_sz > 0: - self.local_window = ( - self.k_cache[vllm_block_ids[-local_window_sz:]], - self.v_cache[vllm_block_ids[-local_window_sz:]], - ) - - def launch_transfer_task(self, transfer_type, block_hashes, vllm_block_ids): - fn = getattr(self.store_instance, transfer_type) - length = len(block_hashes) - block_shape = (self.block_size, self.num_key_heads, self.head_size) - precision = self.k_cache.untyped_storage().element_size() - # TODO: consider is_mla here - is_mla = False - offsets_k = [ - get_offset( - block_shape, - self.rank, - self.tp_size, - precision, - self.layer_id, - is_v=False, - is_mla=is_mla, - ) - ] * length - offsets_v = [ - get_offset( - block_shape, - self.rank, - self.tp_size, - precision, - self.layer_id, - is_v=True, - is_mla=is_mla, - ) - ] * length - key_src_tensors = [self.k_cache[id_] for id_ in vllm_block_ids] - value_src_tensors = [self.v_cache[id_] for id_ in vllm_block_ids] - task_k = fn(block_hashes, offsets_k, key_src_tensors) - task_v = fn(block_hashes, offsets_v, value_src_tensors) - task_k_hash = self.task_hash(block_hashes, transfer_type, "key") - self.tasks[task_k_hash] = task_k - task_v_hash = self.task_hash(block_hashes, transfer_type, "value") - self.tasks[task_v_hash] = task_v - - def extract_block_repre(self, vllm_block_ids): ucm_sparse = get_ucm_sparse() hash_encoder = ucm_sparse.hash_encoder - hashk_cache = hash_encoder.compute_hash(self.k_cache[vllm_block_ids]) - return hashk_cache + hash_query = hash_encoder.compute_hash(query) + query_flat = hash_query.reshape(query.shape[0], -1) + top_k = int(self.sparse_range * self.esa_cfg["sparse_ratio"]) + indexes = 
[self.slots] + self.retrieval_task = self.retrieval_worker.submit( + query_flat, topk=top_k, indexes=indexes + ) - def save_blocks(self, num_blocks_need_dump): - if num_blocks_need_dump <= 0: - return - vllm_block_ids = self.req_meta.vllm_block_ids - num_blocks_dumped = 0 if self.block_repre is None else self.block_repre.shape[0] - block_ids = list( - range(num_blocks_dumped, num_blocks_dumped + num_blocks_need_dump) + def block_repre_data(self): + self.sparse_range = get_sparse_range( + self.esa_cfg["init_window_sz"], + self.esa_cfg["local_window_sz"], + self.req_meta.num_prompt_tokens, + self.block_size, ) - block_hashes = [ - f"{self.block_hash(self.req_meta.request_id, id_)}" for id_ in block_ids + vllm_block_ids = self.req_meta.vllm_block_ids + # torch.save({"k": self.k_cache[vllm_block_ids].cpu(), "v": self.v_cache[vllm_block_ids].cpu()}, + # f"/home/heke/debug/{self.layer_id}.pkl") + vllm_block_ids_dump = vllm_block_ids[ + self.esa_cfg["init_window_sz"] : self.esa_cfg["init_window_sz"] + + self.sparse_range ] - if self.req_meta.stage == SequenceStage.PREFILL: - vllm_block_ids_dump = vllm_block_ids[ - num_blocks_dumped : num_blocks_dumped + num_blocks_need_dump - ] - else: - # TODO: handle spec_decode here - vllm_block_ids_dump = vllm_block_ids[-1:] - self.launch_transfer_task("dump", block_hashes, vllm_block_ids_dump) - repre = self.extract_block_repre(vllm_block_ids_dump) - # [n_blocks, num_kv_heads, block_size, hash_bits//8] - repre = repre.transpose(1, 2).contiguous() - # TODO: pre-allocate can speed up here - if self.block_repre is None: - self.block_repre = repre - else: - self.block_repre = torch.cat([self.block_repre, repre], dim=0) - - def maybe_register_kv_cache(self, forward_context: ForwardContext): - if self.block_size: - return - attn = forward_context.no_compile_layers[self.layer_name] - kv_cache = attn.kv_cache[forward_context.virtual_engine] - # TODO: consider is_mla here - self.k_cache = kv_cache[0] - self.v_cache = kv_cache[1] - self.block_size = self.k_cache.shape[1] - self.num_key_heads = self.k_cache.shape[2] - self.head_size = self.k_cache.shape[3] - - def attention_begin( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - forward_context: ForwardContext, - ) -> None: - if self.req_meta.step % RETRIEVAL_STRIDE != 1: - return - index_in_batch = self.req_meta.index_in_batch - if isinstance(forward_context.attn_metadata, dict): - attn_md = forward_context.attn_metadata[self.layer_name] + ######## 修改表征 + repre = self.extract_block_repre(vllm_block_ids_dump) + repre_flat = repre.reshape(repre.shape[0], repre.shape[1], -1) + new_slots = self.repre_pool.allocate(self.sparse_range) + og_len = len(self.slots) + for i, slot in enumerate(new_slots): + self.slots_to_relative_indexes[slot] = og_len + i + self.slots.extend(new_slots) + vals = repre_flat.to("cpu", non_blocking=True, dtype=torch.uint8) + data[self.layer_id][new_slots] = vals + ############## + + # NOTE: in Preemption, local_window_start != -self.esa_cfg['local_window_sz'] + local_window_start = self.esa_cfg["init_window_sz"] + self.sparse_range + + if not self.is_mla: + self.init_window = ( + self.k_cache[vllm_block_ids[: self.esa_cfg["init_window_sz"]]].clone(), + self.v_cache[vllm_block_ids[: self.esa_cfg["init_window_sz"]]].clone(), + ) + self.local_window = ( + self.k_cache[vllm_block_ids[local_window_start:]].clone(), + self.v_cache[vllm_block_ids[local_window_start:]].clone(), + ) else: - attn_md = forward_context.attn_metadata - query_start_loc = 
attn_md.query_start_loc[index_in_batch] - query_len = self.req_meta.num_scheduled_tokens - current_query = query[query_start_loc : query_start_loc + query_len] - - vllm_block_ids = self.req_meta.vllm_block_ids[ - self.init_window_sz : -self.local_window_sz - ] - self.wait_for_task_done() - self.prepare_init_and_local_window() # last dump task(possible) - # NOTE: sync style - topk_block_hashes = self.retrieval(current_query, len(vllm_block_ids)) - self.launch_transfer_task("load", topk_block_hashes, vllm_block_ids) - - self.wait_for_task_done() + self.init_window = self.k_cache[ + vllm_block_ids[: self.esa_cfg["init_window_sz"]] + ].clone() + self.local_window = self.k_cache[ + vllm_block_ids[local_window_start:] + ].clone() - # NOTE: Some sparse attention algorithms need to modify attn_metadata here - def prepare_init_and_local_window(self): - if self.req_meta.step != 1: - return - - vllm_block_ids = self.req_meta.vllm_block_ids - self.k_cache[vllm_block_ids[: self.init_window_sz]] = self.init_window[0] - self.v_cache[vllm_block_ids[: self.init_window_sz]] = self.init_window[1] - - if self.local_window is None: - return - - self.k_cache[vllm_block_ids[-self.local_window_sz :]] = self.local_window[0] - self.v_cache[vllm_block_ids[-self.local_window_sz :]] = self.local_window[1] - - def wait_for_task_done(self): - for task_hash, task in self.tasks.items(): - # TODO: handle exceptions here, refer to UcmKVConnector - ret = self.store_instance.wait(task) - self.tasks.clear() - - def attention_finished( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_output: torch.Tensor, - forward_context: ForwardContext, - ) -> None: - self.maybe_register_kv_cache(forward_context) - num_tokens_updated = ( - self.req_meta.num_computed_tokens + self.req_meta.num_scheduled_tokens - ) - num_blocks_dumped = 0 if self.block_repre is None else self.block_repre.shape[0] - num_full_blocks = num_tokens_updated // self.block_size - num_blocks_need_dump = num_full_blocks - num_blocks_dumped - self.save_blocks(num_blocks_need_dump) - if self.req_meta.stage == SequenceStage.PREFILL and self.req_meta.is_last_chunk: - self.construct_init_and_local_window() - self.wait_for_task_done() - - -class KvComp(UcmSparseBase): +class KvComp(ESA): # handle batch def __init__(self, vllm_config: VllmConfig, role: UcmSparseRole): - super().__init__(vllm_config, role) - self.req_states: dict[str, ReqStatePerLayer] = {} + UcmSparseBase.__init__(self, vllm_config, role) + self.req_states: dict[str, List[ReqStatePerLayerKvComp]] = {} self.rank = vllm_config.parallel_config.rank self.tp_size = vllm_config.parallel_config.tensor_parallel_size - self.block_size = vllm_config.cache_config.block_size + if role == UcmSparseRole.WORKER: + self.connector = get_kv_transfer_group().connector + else: + self.connector = None + self.total_num_hidden_layers = ( + vllm_config.model_config.hf_config.num_hidden_layers + ) + self.is_mla = vllm_config.model_config.is_deepseek_mla + self._sparse_metadata_prefill: ESASparseMetaData = ESASparseMetaData() + self._sparse_metadata_decode: ESASparseMetaData = ESASparseMetaData() + self._sparse_metadata: ESASparseMetaData = ESASparseMetaData() + self.esa_cfg = vllm_config.kv_transfer_config.kv_connector_extra_config[ + "ucm_sparse_config" + ]["KvComp"] - max_cache_size = vllm_config.kv_transfer_config.kv_connector_extra_config[ - "ucm_connector_config" - ]["max_cache_size"] - config = { - "max_cache_size": max_cache_size, - "device": self.rank, - "role": "worker", - } - self.connector = 
UcmConnectorFactory.create_connector("UcmDramStore", config) + self.block_size = vllm_config.cache_config.block_size + self.num_kv_heads = vllm_config.model_config.get_num_kv_heads( + vllm_config.parallel_config + ) + self.hashk_cache = None kvcomp_config_path = vllm_config.kv_transfer_config.kv_connector_extra_config[ "kvcomp_config_path" ] + self.kvcomp_config = KvCompConfig.from_json(kvcomp_config_path) - logger.info(f"read kvcomp config file: {kvcomp_config_path} ") + logger.info(f"read kvcomp config file : {kvcomp_config_path} ") assert ( - self.kvcomp_config.num_hidden_layers - == vllm_config.model_config.hf_text_config.num_hidden_layers + self.kvcomp_config.num_hidden_layers == self.total_num_hidden_layers ), f"kvcomp_config.num_hidden_layers {self.kvcomp_config.num_hidden_layers} \ - != vllm_config.model_config.hf_text_config.num_hidden_layers \ - {vllm_config.model_config.hf_text_config.num_hidden_layers}" - - dtype = vllm_config.model_config.dtype + != vllm_config.model_config.hf_text_config.num_hidden_layers \ + {self.total_num_hidden_layers}" if hasattr(torch, "npu") and torch.npu.is_available(): device = torch.device(f"npu:{self.rank}") elif torch.cuda.is_available(): device = torch.device(f"cuda:{self.rank}") else: - device = torch.device("cpu") + device = torch.device("npu") + self.hash_encoder = HashEncoder( input_dim=self.kvcomp_config.head_dim, hash_bits=self.kvcomp_config.hash_bits, - dtype=dtype, + dtype=vllm_config.model_config.dtype, device=device, ) + self.device = device - # TODO: consider init self.is_mla here - - def attention_begin( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - layer_name: str, - forward_context: ForwardContext, - ) -> None: - for req_meta in self._sparse_metadata.requests: - req_state_hash = ReqStatePerLayer.req_state_hash( - req_meta.request_id, layer_name - ) - if req_state_hash not in self.req_states: - self.req_states[req_state_hash] = ReqStatePerLayer( - req_meta, layer_name, self.rank, self.tp_size, self.connector - ) - req_state = self.req_states[req_state_hash] - req_state.update_meta(req_meta, forward_context) - req_state.attention_begin(query, key, value, forward_context) + global data - def attention_finished( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_output: torch.Tensor, - layer_name: str, - forward_context: ForwardContext, - ) -> None: - for req_meta in self._sparse_metadata.requests: - req_state_hash = ReqStatePerLayer.req_state_hash( - req_meta.request_id, layer_name + if data is None: + parallel_config = vllm_config.parallel_config + num_slots = ( + vllm_config.model_config.max_model_len + * vllm_config.scheduler_config.max_num_seqs + // vllm_config.cache_config.block_size ) - if req_state_hash not in self.req_states: - self.req_states[req_state_hash] = ReqStatePerLayer( - req_meta, layer_name, self.rank, self.tp_size, self.connector - ) - req_state = self.req_states[req_state_hash] - req_state.update_meta(req_meta, forward_context) - req_state.attention_finished( - query, key, value, attn_output, forward_context + block_size = vllm_config.cache_config.block_size + dim = ( + vllm_config.model_config.get_num_kv_heads(parallel_config) + * self.kvcomp_config.hash_bits # 修改vllm_config.model_config.get_head_size()为hash_bits + // 8 ) + data = [ + torch.empty((num_slots, block_size, dim), dtype=torch.uint8) + for _ in range(self.total_num_hidden_layers) + ] + self.layer_pools: list[ReprePool] = [ + ReprePool(num_slots) for _ in range(self.total_num_hidden_layers) + 
] - def wait_all_task_done(self): - pass - - def execute_finished(self): - pass - - def execute_finished(self): - pass + self.local_tp_rank = vllm_config.parallel_config.rank + self.total_tp_size = vllm_config.parallel_config.tensor_parallel_size + ratio = 0.75 - def build_sparse_meta( - self, - scheduler_output, - requests, - input_batch, - ) -> UcmSparseMetadata: - sparse_meta = KvCompSparseMetaData() - for ( - req_id, - num_scheduled_tokens, - ) in scheduler_output.num_scheduled_tokens.items(): - req_state = requests[req_id] - if len(req_state.prompt_token_ids) > self.block_size: - sparse_meta.add_request( - req_id, - input_batch.req_id_to_index[req_id], - len(req_state.prompt_token_ids), - len(req_state.output_token_ids), - num_scheduled_tokens, - req_state.num_computed_tokens, - scheduler_output.req_sparsed_slots[req_id], - req_state.block_ids[0], - ) - self._sparse_metadata = sparse_meta + bind_info_list, alloc_numa_ids = get_bind_cpus_for_rank( + self.total_tp_size, self.local_tp_rank, ratio=ratio + ) - def request_begin(self, request_id: ReqType, prompt_token_ids: List[int]): - pass + bind_info_dict = defaultdict(list) + for item in bind_info_list: + bind_info_dict[item[1]].append(item[0]) + bind_info_dict = dict(bind_info_dict) - def request_finished_in_scheduler(self, request_id: ReqType): - pass + self.retrieval_workers: List[HashRetrievalWorker] = [] + for i in range(self.total_num_hidden_layers): + backend_src = data[i] + backend = hash_retrieval_backend.HashRetrievalWorkerBackend( + backend_src, bind_info_dict + ) + self.retrieval_workers.append(HashRetrievalWorker(backend)) - def request_finished_in_worker(self, request_id: ReqType): - pass + self.preempt_req_output_tokens: Dict[ReqType, int] = {} - def update_state_after_alloc(self, request: Request, num_blocks: int): - pass + def get_or_create_layerwise_req_state(self, req_meta, layer_name): + layer_id = int(layer_name.split(".")[2]) + if req_meta.is_preempt: + print( + f"preempt {req_meta.request_id}, layer_id: {layer_id}, {req_meta.num_output_tokens}" + ) + layer_state = self.req_states[req_meta.request_id][layer_id] + layer_state.repre_pool.free(layer_state.slots) + self.req_states[req_meta.request_id][layer_id] = None + if req_meta.request_id not in self.req_states: + if self.req_states.get(req_meta.request_id) is None: + self.req_states[req_meta.request_id] = [ + None + ] * self.total_num_hidden_layers + if self.req_states[req_meta.request_id][layer_id] is None: + self.req_states[req_meta.request_id][layer_id] = ReqStatePerLayerKvComp( + layer_name, + self.rank, + self.tp_size, + self.connector, + self._vllm_config, + self.retrieval_workers[layer_id], + self.layer_pools[layer_id], + ) + return self.req_states[req_meta.request_id][layer_id] - def estimate_num_slots_sparsed(self, request: Request) -> int: - if ( - request.num_output_tokens == 0 - or request.num_prompt_tokens < self.block_size - ): - return INVALID_SLOT - num_blocks = math.ceil(request.num_tokens / self.block_size) - mid_window_sz = int( - (num_blocks - INIT_WINDOW_SZ - LOCAL_WINDOW_SZ) * SPARSE_RATIO - ) - flaw = request.num_tokens % self.block_size - if flaw: - flaw = self.block_size - flaw - num_tokens_sparsed = ( - INIT_WINDOW_SZ + mid_window_sz + LOCAL_WINDOW_SZ - ) * self.block_size - flaw - return num_tokens_sparsed + def execute_begin(self, scheduler_output): + if self.hashk_cache is None: + print( + " ========================== initialize hashk cache ========================== " + ) + num_blocks = self._vllm_config.cache_config.num_gpu_blocks + 
+            self.hashk_cache = [
+                torch.empty(
+                    (
+                        num_blocks,
+                        self.num_kv_heads,
+                        self.block_size,
+                        self.hash_encoder.hash_bits // 8,
+                    ),
+                    dtype=torch.uint8,
+                    device=self.device,
+                )
+                for _ in range(self.total_num_hidden_layers)
+            ]
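
For reference, the kernel added in hash_retrieval_backend.cpp scores each candidate block by the smallest per-token Hamming distance between the packed query hash and the block's packed key hashes (XOR followed by popcount, vectorized with NEON/SSE), then keeps the topk blocks with the lowest scores. A minimal, unoptimized NumPy sketch of that scoring logic (the function name and shapes are illustrative, not part of the patch):

import numpy as np

# popcount lookup table for one byte
_POPCNT = np.unpackbits(np.arange(256, dtype=np.uint8)[:, None], axis=1).sum(axis=1)

def topk_blocks(query, data, allowed, topk):
    """query: packed uint8 hash [dim]; data: packed uint8 hashes [num_blocks, block_size, dim]."""
    scores = []
    for idx in allowed:
        # per-token Hamming distance = popcount(query XOR key), summed over bytes
        dists = _POPCNT[np.bitwise_xor(data[idx], query)].sum(axis=1)
        scores.append(int(dists.min()))  # block score = best-matching token
    order = np.argsort(scores)[:topk]    # smaller distance = more similar
    return [allowed[i] for i in order], [scores[i] for i in order]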
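
The hash_encoder.py change replaces the plain Gaussian projection with an orthogonalized one: a Gaussian matrix is QR-decomposed and the columns of Q are sign-corrected with the diagonal of R, which makes the projection Haar-distributed (uniform over orthogonal matrices) rather than biased by the QR sign convention. A self-contained sketch of that construction (the helper name is illustrative):

import torch

def haar_hash_weights(input_dim, hash_bits, dtype=torch.float32, device="cpu"):
    g = torch.randn(input_dim, hash_bits, dtype=dtype, device=device) * 2.0  # std-2 Gaussian, as in the patch
    q, r = torch.linalg.qr(g)             # reduced QR of the random matrix
    return q * torch.sign(torch.diag(r))  # column-wise sign correction -> Haar distribution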
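
For intuition on the sparsity budget set by the KvComp config in examples/offline_inference_kvcomp.py (init_window_sz=1, local_window_sz=2, sparse_ratio=0.3, block_size=128), the window arithmetic used by the estimate_num_slots_sparsed path works out roughly as follows for a 12,800-token prompt (numbers are illustrative only):

num_blocks = 12800 // 128                   # 100 full KV blocks
mid_kept = int((num_blocks - 1 - 2) * 0.3)  # 29 middle blocks kept by hash retrieval
print((1 + mid_kept + 2) * 128)             # ~4096 tokens attended per layer at decode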