167 changes: 167 additions & 0 deletions examples/offline_inference_kvcomp.py
@@ -0,0 +1,167 @@
import contextlib
import json
import os
import sys
import time
from dataclasses import asdict

from transformers import AutoTokenizer

# Third Party
from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig
from vllm.engine.arg_utils import EngineArgs

from ucm.logger import init_logger

logger = init_logger(__name__)
model = ""
path_to_dataset = ""
data_dir = ""
tokenizer = None


def setup_environment_variables():
os.environ["VLLM_USE_V1"] = "1"
os.environ["PYTHONHASHSEED"] = "123456"

global model, path_to_dataset, data_dir, tokenizer
model = os.getenv("MODEL_PATH", "/home/models/Qwen2.5-14B-Instruct")
    if not os.path.isdir(model):
        model = input("Enter the path to the model, e.g. /home/models/Qwen2.5-14B-Instruct: ")
        if not os.path.isdir(model):
            print("Exiting. Incorrect model path.")
            sys.exit(1)

path_to_dataset = os.getenv(
"DATASET_PATH", "/home/data/Longbench/data/multifieldqa_zh.jsonl"
)
    if not os.path.isfile(path_to_dataset):
        path_to_dataset = input(
            "Enter the path to one of the LongBench dataset files, e.g. /home/data/Longbench/data/multifieldqa_zh.jsonl: "
        )
        if not os.path.isfile(path_to_dataset):
            print("Exiting. Incorrect dataset path.")
            sys.exit(1)

    data_dir = os.getenv("DATA_DIR", "/home/data/kv_cache")
    if not os.path.isdir(data_dir):
        data_dir = input(
            "Enter the directory for UCMStore to save kv cache, e.g. /home/data/kv_cache: "
        )
    if not os.path.isdir(data_dir):
        create = input(f"Directory {data_dir} does not exist. Create it? (Y/n): ")
        if create.lower() in ("", "y"):
            os.makedirs(data_dir, exist_ok=True)
        else:
            print("Exiting. Directory not created.")
            sys.exit(1)

    tokenizer = AutoTokenizer.from_pretrained(model)


@contextlib.contextmanager
def build_llm_with_uc(module_path: str, name: str, model: str):
ktc = KVTransferConfig(
kv_connector=name,
kv_connector_module_path=module_path,
kv_role="kv_both",
kv_connector_extra_config={
"ucm_connector_name": "UcmNfsStore",
"ucm_connector_config": {
"storage_backends": data_dir,
"kv_block_size": 33554432,
},
"ucm_sparse_config": {
"KvComp": {
"init_window_sz": 1,
"local_window_sz": 2,
"min_blocks": 4,
"sparse_ratio": 0.3,
"retrieval_stride": 5,
}
},
# "kvcomp_config_path": "unified-cache-management/ucm/sparse/kvcomp/configs/kvcomp_deepseek_v2_lite_config.json",
"kvcomp_config_path": "unified-cache-management/ucm/sparse/kvcomp/configs/kvcomp_qwq_32B_config.json",
},
)

llm_args = EngineArgs(
model=model,
kv_transfer_config=ktc,
max_model_len=32768,
gpu_memory_utilization=0.8,
max_num_batched_tokens=30000,
block_size=128,
enforce_eager=True,
distributed_executor_backend="mp",
tensor_parallel_size=2,
trust_remote_code=True,
)

llm = LLM(**asdict(llm_args))
try:
yield llm
finally:
logger.info("LLM engine is exiting.")


def print_output(
llm: LLM,
    prompts: list[str],
sampling_params: SamplingParams,
req_str: str,
):
start = time.time()
    outputs = llm.generate(prompts, sampling_params)
print("-" * 50)
for output in outputs:
generated_text = output.outputs[0].text
print(f"Generated text: {generated_text!r}")
print(f"Generation took {time.time() - start:.2f} seconds, {req_str} request done.")
print("-" * 50)


def main():
module_path = "ucm.integration.vllm.uc_connector"
name = "UnifiedCacheConnectorV1"
setup_environment_variables()

def get_prompt(prompt):
messages = [
{
"role": "system",
"content": "先读问题,再根据下面的文章内容回答问题,不要进行分析,不要重复问题,用简短的语句给出答案。\n\n例如:“全国美国文学研究会的第十八届年会在哪所大学举办的?”\n回答应该为:“xx大学”。\n\n",
},
{"role": "user", "content": prompt},
]
return tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
add_special_tokens=True,
)

with build_llm_with_uc(module_path, name, model) as llm:
prompts = []
batch_size = 10
        assert os.path.isfile(path_to_dataset), (
            "Incorrect dataset path. Please specify it via "
            "`export DATASET_PATH=/path/to/longbench/multifieldqa_zh.jsonl`."
        )
with open(path_to_dataset, "r") as f:
lines = f.readlines()
for i in range(batch_size):
line = lines[i]
data = json.loads(line)
prompt = f"""阅读以下文字并用中文简短回答:\n\n{data["context"]}\n\n现在请基于上面的文章回答下面的问题,只告诉我答案,不要输出任何其他字词。\n\n问题:{data["input"]}\n回答:"""
prompts.append(get_prompt(prompt))

sampling_params = SamplingParams(
temperature=0, top_p=0.95, max_tokens=256, ignore_eos=False
)

print_output(llm, prompts, sampling_params, "first")
print_output(llm, prompts, sampling_params, "second")


if __name__ == "__main__":
main()
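A minimal invocation sketch for the example above (assuming the repository layout shown; these environment variables are the ones read by setup_environment_variables, which falls back to interactive prompts for any missing path):

    MODEL_PATH=/home/models/Qwen2.5-14B-Instruct \
    DATASET_PATH=/home/data/Longbench/data/multifieldqa_zh.jsonl \
    DATA_DIR=/home/data/kv_cache \
    python examples/offline_inference_kvcomp.py

The script issues the same batch of prompts twice; the second pass should be noticeably faster when the KV blocks persisted by UcmNfsStore are reused.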
41 changes: 41 additions & 0 deletions ucm/sparse/esa/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1 +1,42 @@
include(FetchContent)  # ensure FetchContent_Declare is available even if the top level has not included it

set(NUMA_INSTALL_DIR ${CMAKE_CURRENT_BINARY_DIR}/numa_install)
FetchContent_Declare(
numactl
URL https://github.com/numactl/numactl/releases/download/v2.0.16/numactl-2.0.16.tar.gz
TLS_VERIFY OFF
)
FetchContent_MakeAvailable(numactl)
if(NOT EXISTS "${NUMA_INSTALL_DIR}/lib/libnuma.so")
message(STATUS "Configuring numactl...")
execute_process(
COMMAND ./configure --prefix=${NUMA_INSTALL_DIR}
WORKING_DIRECTORY ${numactl_SOURCE_DIR}
RESULT_VARIABLE numa_configure_result
OUTPUT_VARIABLE numa_configure_output
ERROR_VARIABLE numa_configure_error
)
if(NOT numa_configure_result EQUAL 0)
message(FATAL_ERROR "Failed to configure numactl. \n"
"Result: ${numa_configure_result}\n"
"STDOUT: ${numa_configure_output}\n"
"STDERR: ${numa_configure_error}\n")
endif()

message(STATUS "Building and installing numactl...")
execute_process(
COMMAND make install -j8
WORKING_DIRECTORY ${numactl_SOURCE_DIR}
RESULT_VARIABLE numa_install_result
OUTPUT_VARIABLE numa_install_output
ERROR_VARIABLE numa_install_error
)
if(NOT numa_install_result EQUAL 0)
message(FATAL_ERROR "Failed to build and install numactl. \n"
"Result: ${numa_install_result}\n"
"STDOUT: ${numa_install_output}\n"
"STDERR: ${numa_install_error}\n")
endif()
else()
message(STATUS "Found already built libnuma. Skipping build.")
endif()

add_subdirectory(retrieval)
25 changes: 21 additions & 4 deletions ucm/sparse/esa/esa.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import hashlib
import math
import pickle
from collections import defaultdict
from dataclasses import dataclass
from functools import cache
from typing import Dict, List, Optional, Union
@@ -23,6 +24,7 @@
)
from ucm.sparse.esa.retrieval import retrieval_backend
from ucm.sparse.esa.retrieval.retrieval_worker import RetrievalWorker
from ucm.sparse.kvstar.utils import get_bind_cpus_for_rank
from ucm.store.ucmstore import Task, UcmKVStoreBase

ReqType = Union[str, int]
@@ -203,9 +205,9 @@ def __init__(
self.rank = rank
self.tp_size = tp_size
self.tasks: Dict[str, Task] = {}
        self.esa_cfg = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
            "ucm_sparse_config", {}
        ).get("ESA", None)
self.indexes: Optional[NDArray[np.int64]] = None
self.block_hashes = None
self.pre_topk_block_hashes: Dict[int, str] = {}
@@ -502,10 +504,25 @@ def __init__(self, vllm_config: VllmConfig, role: UcmSparseRole):
ReprePool(num_slots) for _ in range(self.total_num_hidden_layers)
]

self.local_tp_rank = vllm_config.parallel_config.rank
self.total_tp_size = vllm_config.parallel_config.tensor_parallel_size
ratio = 0.75

bind_info_list, alloc_numa_ids = get_bind_cpus_for_rank(
self.total_tp_size, self.local_tp_rank, ratio=ratio
)

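        # Group the (core_id, numa_id) pairs from get_bind_cpus_for_rank by
        # NUMA node, yielding {numa_id: [core_id, ...]} for the retrieval backend.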
bind_info_dict = defaultdict(list)
for item in bind_info_list:
bind_info_dict[item[1]].append(item[0])
bind_info_dict = dict(bind_info_dict)

self.retrieval_workers: List[RetrievalWorker] = []
for i in range(self.total_num_hidden_layers):
backend_src = data[i]
            backend = retrieval_backend.RetrievalWorkerBackend(
                backend_src, bind_info_dict
            )
self.retrieval_workers.append(RetrievalWorker(backend))

self.preempt_req_output_tokens: Dict[ReqType, int] = {}
15 changes: 15 additions & 0 deletions ucm/sparse/esa/retrieval/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,17 @@
# Add the compilation target
pybind11_add_module(retrieval_backend cpy/retrieval_backend.cpp)

# Set the output directory for the built library
set_target_properties(retrieval_backend PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})

# Add include directories so that numaif.h can be found
target_include_directories(retrieval_backend PUBLIC
${NUMA_INSTALL_DIR}/include
${Torch_INCLUDE_DIRS}
)

# Link the required libraries
target_link_libraries(retrieval_backend PUBLIC
${NUMA_INSTALL_DIR}/lib/libnuma.so
${Torch_LIBRARIES}
)
39 changes: 34 additions & 5 deletions ucm/sparse/esa/retrieval/cpy/retrieval_backend.cpp
Original file line number Diff line number Diff line change
@@ -11,12 +11,15 @@
#include <vector>
#include <algorithm>
#include <random>
#include <numaif.h>
#include <pthread.h>   // pthread_setaffinity_np
#include <sched.h>     // cpu_set_t, CPU_ZERO, CPU_SET
#include <iostream>

namespace py = pybind11;

class RetrievalWorkerBackend {
public:
    RetrievalWorkerBackend(py::array_t<float> data,
                           py::dict cpu_idx_tbl)
: data_array_(data), stop_workers_(false), next_req_id_(0)
{
py::buffer_info info = data_array_.request();
@@ -25,9 +28,35 @@ class RetrievalWorkerBackend {
data_ = static_cast<const float*>(info.ptr);

// Start worker threads
        for (auto cpu_idx : cpu_idx_tbl) {
            int numaId = cpu_idx.first.cast<int>();
            py::list core_ids = cpu_idx.second.cast<py::list>();

            for (size_t i = 0; i < core_ids.size(); ++i) {
                int core_id = core_ids[i].cast<int>();
                // Apply CPU and memory affinity from inside the worker thread:
                // set_mempolicy() only affects the calling thread, so calling it
                // here in the constructor would not bind the workers' allocations.
                worker_threads_.emplace_back([this, core_id, numaId]() {
                    // Pin this worker thread to its assigned core.
                    cpu_set_t cpuset;
                    CPU_ZERO(&cpuset);
                    CPU_SET(core_id, &cpuset);
                    int rc = pthread_setaffinity_np(pthread_self(), sizeof(cpu_set_t), &cpuset);
                    if (rc != 0) {
                        std::cerr << "Error binding thread to CPU core " << core_id << std::endl;
                    }

                    // Bind this thread's memory allocations to the matching NUMA node.
                    unsigned long nodeMask = 1UL << numaId;
                    rc = set_mempolicy(MPOL_BIND, &nodeMask, sizeof(nodeMask) * 8);
                    if (rc != 0) {
                        std::cerr << "Error binding memory to NUMA node " << numaId << std::endl;
                    }

                    worker_loop();
                });
            }
        }
}

@@ -217,7 +246,7 @@ class RetrievalWorkerBackend {

PYBIND11_MODULE(retrieval_backend, m) {
py::class_<RetrievalWorkerBackend>(m, "RetrievalWorkerBackend")
        .def(py::init<py::array_t<float>, py::dict>())
.def("submit", &RetrievalWorkerBackend::submit)
.def("poll", &RetrievalWorkerBackend::poll)
.def("get_result", &RetrievalWorkerBackend::get_result)
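For reference, a minimal sketch of how the rebound constructor is driven from Python (mirroring the esa.py change above; the array shape and core map below are illustrative values, not taken from this PR):

    import numpy as np
    from ucm.sparse.esa.retrieval import retrieval_backend

    data_slab = np.zeros((1024, 128), dtype=np.float32)  # per-layer representation pool (illustrative)
    cpu_idx_tbl = {0: [0, 1, 2, 3], 1: [8, 9, 10, 11]}   # NUMA node id -> bound core ids (illustrative)
    backend = retrieval_backend.RetrievalWorkerBackend(data_slab, cpu_idx_tbl)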
42 changes: 42 additions & 0 deletions ucm/sparse/kvcomp/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
include(FetchContent)  # ensure FetchContent_Declare is available even if the top level has not included it

set(NUMA_INSTALL_DIR ${CMAKE_CURRENT_BINARY_DIR}/numa_install)
FetchContent_Declare(
numactl
URL https://github.com/numactl/numactl/releases/download/v2.0.16/numactl-2.0.16.tar.gz
TLS_VERIFY OFF
)
FetchContent_MakeAvailable(numactl)
if(NOT EXISTS "${NUMA_INSTALL_DIR}/lib/libnuma.so")
message(STATUS "Configuring numactl...")
execute_process(
COMMAND ./configure --prefix=${NUMA_INSTALL_DIR}
WORKING_DIRECTORY ${numactl_SOURCE_DIR}
RESULT_VARIABLE numa_configure_result
OUTPUT_VARIABLE numa_configure_output
ERROR_VARIABLE numa_configure_error
)
if(NOT numa_configure_result EQUAL 0)
message(FATAL_ERROR "Failed to configure numactl. \n"
"Result: ${numa_configure_result}\n"
"STDOUT: ${numa_configure_output}\n"
"STDERR: ${numa_configure_error}\n")
endif()

message(STATUS "Building and installing numactl...")
execute_process(
COMMAND make install -j8
WORKING_DIRECTORY ${numactl_SOURCE_DIR}
RESULT_VARIABLE numa_install_result
OUTPUT_VARIABLE numa_install_output
ERROR_VARIABLE numa_install_error
)
if(NOT numa_install_result EQUAL 0)
message(FATAL_ERROR "Failed to build and install numactl. \n"
"Result: ${numa_install_result}\n"
"STDOUT: ${numa_install_output}\n"
"STDERR: ${numa_install_error}\n")
endif()
else()
message(STATUS "Found already built libnuma. Skipping build.")
endif()

add_subdirectory(hash_retrieval)