66 changes: 49 additions & 17 deletions examples/offline_inference.py
@@ -1,15 +1,20 @@
import contextlib
import json
import os
import time
from dataclasses import asdict

from transformers import AutoTokenizer

# Third Party
from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig
from vllm.engine.arg_utils import EngineArgs

from ucm.logger import init_logger

MODEL_PATH = "/home/models/Qwen2.5-14B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
logger = init_logger(__name__)


@@ -25,21 +30,30 @@ def build_llm_with_uc(module_path: str, name: str, model: str):
kv_connector_module_path=module_path,
kv_role="kv_both",
kv_connector_extra_config={
"ucm_connector_name": "UcmDramStore",
"ucm_connector_name": "UcmNfsStore",
"ucm_connector_config": {
"max_cache_size": 53687091200,
"kv_block_size": 262144,
"storage_backends": "/home/data",
"kv_block_size": 33554432,
},
"ucm_sparse_config": {
"ESA": {
"init_window_sz": 1,
"local_window_sz": 2,
"min_blocks": 4,
"sparse_ratio": 0.3,
"retrieval_stride": 5,
}
},
"ucm_sparse_method": "GSA",
},
)

llm_args = EngineArgs(
model=model,
kv_transfer_config=ktc,
max_model_len=40960,
gpu_memory_utilization=0.87,
max_model_len=32768,
gpu_memory_utilization=0.8,
max_num_batched_tokens=30000,
block_size=128,
enforce_eager=True,
)

llm = LLM(**asdict(llm_args))
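The remainder of build_llm_with_uc is collapsed in this view. A minimal sketch of how such a contextlib-based wrapper is typically closed out (an assumption about the elided lines, not necessarily this PR's exact code):

@contextlib.contextmanager
def build_llm_with_uc(module_path: str, name: str, model: str):
    # ... KVTransferConfig / EngineArgs construction as shown above ...
    llm = LLM(**asdict(llm_args))
    try:
        # Hand the engine to the caller's `with` block.
        yield llm
    finally:
        # Drop the engine so its resources are released when the block exits.
        del llm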
@@ -72,17 +86,35 @@ def main():

setup_environment_variables()

with build_llm_with_uc(module_path, name, model) as llm:
prompts = [
"Imagine you are an artificial intelligence developed in the year 2075, designed to assist humanity in "
"navigating the complex ethical, philosophical, and technological challenges of a rapidly evolving world. "
"You have access to vast historical records, scientific data, and human literature, and your core "
"directive is to promote sustainable development, social equity, and the flourishing of conscious beings. "
"Write a detailed letter to the leaders of Earth, explaining the most urgent global issue of the 21st "
"century, the root sauses behind it, and a set of scientifically grounded, morally sound, and globally "
"cooperative solutions that transcend culturak and national boundaries. Include both immediate actions "
"and long-term strategies." * 200
def get_prompt(prompt):
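# The system prompt below is Chinese, matching the LongBench multifieldqa_zh
# dataset. It says: "Read the question first, then answer it from the article
# below. Do not analyze and do not repeat the question; give the answer as a
# short phrase. E.g. for 'Which university hosted the 18th annual conference
# of the National American Literature Research Association?' the answer
# should be 'XX University'."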
messages = [
{
"role": "system",
"content": "先读问题,再根据下面的文章内容回答问题,不要进行分析,不要重复问题,用简短的语句给出答案。\n\n例如:“全国美国文学研究会的第十八届年会在哪所大学举办的?”\n回答应该为:“xx大学”。\n\n",
},
{"role": "user", "content": prompt},
]
return tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
add_special_tokens=True,
)

with build_llm_with_uc(module_path, name, model) as llm:
prompts = []

batch_size = 1

with open("/home/datasets/Longbench/data/multifieldqa_zh.jsonl", "r") as f:
for _ in range(batch_size):
line = f.readline()
if not line:
break
data = json.loads(line)
context = data["context"]
question = data["input"]
prompts.append(get_prompt(f"{context}\n\n{question}"))

sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=100)
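The generation call that consumes these prompts is collapsed below; for reference, a standard vLLM invocation with this configuration would look like the following sketch (not part of the diff):

outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    # Each RequestOutput holds the prompt plus one CompletionOutput per sample.
    print(output.outputs[0].text)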

15 changes: 11 additions & 4 deletions setup.py
@@ -35,9 +35,11 @@
STORE_SRC_DIR = os.path.join(ROOT_DIR, "ucm", "store")
GSA_SRC_DIR = os.path.join(ROOT_DIR, "ucm", "csrc", "gsaoffloadops")
PREFETCH_SRC_DIR = os.path.join(ROOT_DIR, "ucm", "csrc", "ucmprefetch")
RETRIEVAL_SRC_DIR = os.path.join(ROOT_DIR, "ucm", "csrc", "esaretrieval")

STORE_INSTALL_DIR = os.path.join(ROOT_DIR, "ucm", "store", "connector")
GSA_INSTALL_DIR = os.path.join(ROOT_DIR, "ucm", "ucm_sparse")
RETRIEVAL_INSTALL_DIR = os.path.join(ROOT_DIR, "ucm", "ucm_sparse", "retrieval")

PLATFORM = os.getenv("PLATFORM")

@@ -89,7 +91,7 @@ def build_cmake(self, ext: CMakeExtension):

subprocess.check_call(cmake_args, cwd=build_dir)

if ext.name in ["store", "gsa_offload_ops"]:
if ext.name in ["store", "gsa_offload_ops", "esaretrieval"]:
subprocess.check_call(["make", "-j", "8"], cwd=build_dir)
else:
# For gsa_prefetch, use cmake --build instead
@@ -115,6 +117,8 @@ def _copy_so_files(self, ext: CMakeExtension):
search_patterns.extend(["gsa_offload_ops"])
elif ext.name == "gsa_prefetch":
search_patterns.extend(["prefetch"])
elif ext.name == "esaretrieval":
search_patterns.extend(["retrieval_backend"])

for file in os.listdir(so_search_dir):
if file.endswith(".so") or ".so." in file:
@@ -124,8 +128,11 @@ def _copy_so_files(self, ext: CMakeExtension):
break
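# install_dir is the in-tree destination used for development installs
# (dev_path below); build_install_dir is the relative path under build_lib
# that ends up packaged in the wheel (dst_path below).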

if ext.name == "store":
install_dir = STORE_INSTALL_DIR
build_install_dir = STORE_INSTALL_DIR
install_dir = FSSTORE_INSTALL_DIR
build_install_dir = "ucm/store"
elif ext.name == "esaretrieval":
install_dir = RETRIEVAL_INSTALL_DIR
build_install_dir = "ucm/ucm_sparse/retrieval"
else:
install_dir = GSA_INSTALL_DIR
build_install_dir = "ucm/ucm_sparse"
@@ -134,7 +141,6 @@ def _copy_so_files(self, ext: CMakeExtension):
src_path = os.path.join(so_search_dir, so_file)
dev_path = os.path.join(install_dir, so_file)
dst_path = os.path.join(self.build_lib, build_install_dir, so_file)

os.makedirs(os.path.dirname(dst_path), exist_ok=True)
shutil.copy(src_path, dst_path)
print(f"[INFO] Copied {so_file} → {dst_path}")
@@ -149,6 +155,7 @@ def _copy_so_files(self, ext: CMakeExtension):
ext_modules.append(CMakeExtension(name="store", sourcedir=STORE_SRC_DIR))
ext_modules.append(CMakeExtension(name="gsa_offload_ops", sourcedir=GSA_SRC_DIR))
ext_modules.append(CMakeExtension(name="gsa_prefetch", sourcedir=PREFETCH_SRC_DIR))
ext_modules.append(CMakeExtension(name="esaretrieval", sourcedir=RETRIEVAL_SRC_DIR))

setup(
name="ucm",
32 changes: 32 additions & 0 deletions ucm/csrc/esaretrieval/CMakeLists.txt
@@ -0,0 +1,32 @@
cmake_minimum_required(VERSION 3.14)
project(retrieval_backend LANGUAGES CXX)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

find_package(Python3 COMPONENTS Interpreter Development REQUIRED)

include(FetchContent)
FetchContent_Declare(
pybind11
GIT_REPOSITORY https://github.com/pybind/pybind11.git
GIT_TAG v2.13.6
GIT_SHALLOW TRUE
)
FetchContent_MakeAvailable(pybind11)

pybind11_add_module(retrieval_backend
retrieval_backend.cpp
)

set(OUTPUT_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/output)
set_target_properties(retrieval_backend PROPERTIES
PREFIX ""
SUFFIX ".so"
LIBRARY_OUTPUT_DIRECTORY "${OUTPUT_ROOT}/lib"
RUNTIME_OUTPUT_DIRECTORY "${OUTPUT_ROOT}/bin"
ARCHIVE_OUTPUT_DIRECTORY "${OUTPUT_ROOT}/lib"
)

target_compile_options(retrieval_backend PRIVATE -O3 -Wall -fPIC)
target_link_libraries(retrieval_backend PRIVATE Python3::Python)
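The pybind11 source retrieval_backend.cpp itself is not part of this diff. Once setup.py copies the built .so into ucm/ucm_sparse/retrieval, a hypothetical import smoke test (package path inferred from RETRIEVAL_INSTALL_DIR above; the module may expose a different surface) is:

# Assumption: ucm/ucm_sparse/retrieval is importable and contains the
# retrieval_backend extension produced by the pybind11_add_module target above.
from ucm.ucm_sparse.retrieval import retrieval_backend
print(retrieval_backend.__doc__)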