diff --git a/lightllm/common/cpu_cache/__init__.py b/lightllm/common/cpu_cache/__init__.py
new file mode 100644
index 0000000000..3500289197
--- /dev/null
+++ b/lightllm/common/cpu_cache/__init__.py
@@ -0,0 +1,3 @@
+from .creator import CpuCacheCreator, CpuCacheTensorSpec
+
+__all__ = ["CpuCacheCreator", "CpuCacheTensorSpec"]
diff --git a/lightllm/common/cpu_cache/creator.py b/lightllm/common/cpu_cache/creator.py
new file mode 100644
index 0000000000..7d03f0c89c
--- /dev/null
+++ b/lightllm/common/cpu_cache/creator.py
@@ -0,0 +1,52 @@
+import ctypes
+import torch
+import numpy as np
+from dataclasses import dataclass
+from typing import Optional, Tuple
+from lightllm.utils.kv_cache_utils import attach_shm_kv_cache_ptr, create_shm_kv_cache_ptr, register_shm_ptr_to_pin
+
+
+@dataclass(frozen=True)
+class CpuCacheTensorSpec:
+    """Shm key, shape, dtype and byte size of a shared-memory backed CPU cache tensor."""
+
+    shm_key: int
+    shape: Tuple[int, ...]
+    dtype: torch.dtype
+    size_bytes: int
+
+
+class CpuCacheCreator:
+    """Builds a torch tensor view over the shm segment described by a CpuCacheTensorSpec."""
+
+    def __init__(self, tensor_spec: CpuCacheTensorSpec):
+        self.tensor_spec = tensor_spec
+
+    def create_or_attach(
+        self, init_shm_data: bool, pin: bool, pin_no_blocking: bool
+    ) -> Tuple[Optional[torch.Tensor], Optional[object]]:
+        """Create (init_shm_data) or attach the shm segment; return (tensor view, pin handle or None)."""
+        if init_shm_data:
+            shm_ptr = create_shm_kv_cache_ptr(key=self.tensor_spec.shm_key, size=self.tensor_spec.size_bytes)
+        else:
+            shm_ptr = attach_shm_kv_cache_ptr(key=self.tensor_spec.shm_key, size=self.tensor_spec.size_bytes)
+
+        if pin:
+            attach_handle = register_shm_ptr_to_pin(shm_ptr=shm_ptr, size=self.tensor_spec.size_bytes)
+            # Block until pinning completes unless the caller asked not to wait.
+            if not pin_no_blocking:
+                attach_handle.wait()
+            cpu_cache_tensor = self._build_tensor_view(shm_ptr=shm_ptr)
+            assert shm_ptr == cpu_cache_tensor.data_ptr()
+            return cpu_cache_tensor, attach_handle
+        else:
+            cpu_cache_tensor = self._build_tensor_view(shm_ptr=shm_ptr)
+            return cpu_cache_tensor, None
+
+    def _build_tensor_view(self, shm_ptr: int) -> torch.Tensor:
+        """View size_bytes bytes at address shm_ptr as a tensor with the spec's dtype and shape."""
+        numpy_array = np.frombuffer(
+            memoryview((ctypes.c_uint8 * self.tensor_spec.size_bytes).from_address(shm_ptr)),
+            dtype=np.uint8,
+        )
+        return torch.from_numpy(numpy_array).view(dtype=self.tensor_spec.dtype).view(self.tensor_spec.shape)
diff --git a/lightllm/server/embed_cache/allocator.py b/lightllm/server/embed_cache/allocator.py
new file mode 100644
index 0000000000..6f1d3dddea
--- /dev/null
+++ b/lightllm/server/embed_cache/allocator.py
@@ -0,0 +1,93 @@
+from typing import Optional
+
+from sortedcontainers import SortedSet
+
+
+class MemoryBlock:
+    """Memory block: represents one contiguous memory region."""
+
+    def __init__(self, start, end):
+        self.start = start
+        self.end = end
+
+    def size(self):
+        return self.end - self.start
+
+    def __repr__(self):
+        return f"Block(start={self.start}, end={self.end})"
+
+    def can_merge(self, block: "MemoryBlock"):
+        return (self.start == block.end) or (block.start == self.end)
+
+
+class MemoryManager:
+    def __init__(self, total_size):
+        """
+        Initialize the memory manager.
+        :param total_size: total memory size
+        """
+        self.total_size = total_size
+        self.mem_set_by_start = SortedSet(key=lambda x: (x.start, x.size()))
+        self.mem_set_by_size = SortedSet(key=lambda x: (x.size(), x.start))
+        total = MemoryBlock(0, total_size)
+        self.__add(total)
+
+    def alloc(self, need_size: int) -> Optional[MemoryBlock]:
+        """Carve need_size units from the smallest free block that fits; return None if none fits."""
+        assert need_size > 0
+
+        if len(self.mem_set_by_size) == 0:
+            return None
+
+        key = MemoryBlock(start=-1, end=-1 + need_size)
+        find_index = self.mem_set_by_size.bisect_left(key)
+        if find_index < len(self.mem_set_by_size):
+            finded_mem_block: MemoryBlock = self.mem_set_by_size[find_index]
+            self.__del(finded_mem_block)
+            ret_mem_block = MemoryBlock(
+                start=finded_mem_block.start,
+                end=finded_mem_block.start + need_size,
+            )
+            left_block = MemoryBlock(
+                start=finded_mem_block.start + need_size,
+                end=finded_mem_block.end,
+            )
+            if left_block.size() > 0:
+                self.__add(left_block)
+
+            return ret_mem_block
+        else:
+            return None
+
+    def release(self, block: MemoryBlock):
+        """Return a block to the free sets, merging it with adjacent free blocks first."""
+        if block is None:
+            return
+        if len(self.mem_set_by_size) == 0:
+            self.__add(block)
+            return
+
+        finded_index = self.mem_set_by_start.bisect_left(block)
+        for index in [finded_index - 1, finded_index, finded_index + 1]:
+            if index < len(self.mem_set_by_start):
+                sub_block: MemoryBlock = self.mem_set_by_start[index]
+                # merge
+                if block.can_merge(sub_block):
+                    self.__del(sub_block)
+                    merge_block = MemoryBlock(
+                        start=min(block.start, sub_block.start),
+                        end=max(block.end, sub_block.end),
+                    )
+                    self.release(merge_block)
+                    return
+        # Nothing adjacent to merge with: add the block as-is.
+        self.__add(block)
+        return
+
+    def __add(self, block):
+        self.mem_set_by_start.add(block)
+        self.mem_set_by_size.add(block)
+
+    def __del(self, block):
+        self.mem_set_by_start.remove(block)
+        self.mem_set_by_size.remove(block)
diff --git a/lightllm/server/embed_cache/embed_cache_client.py b/lightllm/server/embed_cache/embed_cache_client.py
index 6fcc2d3783..b72d8b2c5e 100644
--- a/lightllm/server/embed_cache/embed_cache_client.py
+++ b/lightllm/server/embed_cache/embed_cache_client.py
@@ -1,12 +1,11 @@
-import ctypes
 import torch
-import numpy as np
-from sortedcontainers import SortedSet
-from typing import Optional, List
+from typing import Optional
 from lightllm.utils.envs_utils import get_env_start_args
 from lightllm.utils.log_utils import init_logger
 from lightllm.utils.embed_utils import calcu_embed_cache_meta
-from lightllm.utils.kv_cache_utils import create_shm_kv_cache_ptr, attach_shm_kv_cache_ptr, register_shm_ptr_to_pin
+from lightllm.common.cpu_cache import CpuCacheCreator, CpuCacheTensorSpec
+from .allocator import MemoryBlock, MemoryManager
+from .copy_to_cache import offload_embed_tensor_to_cache
 
 logger = init_logger(__name__)
 
@@ -25,10 +24,22 @@ def __init__(self, create_meta_data: bool, init_shm_data: bool):
         if create_meta_data:
             self.token_index_manager = MemoryManager(total_size=self.token_num)
 
-        if init_shm_data:
-            self._create_shm_embed_kv_cache()
-        else:
-            self._attach_shm_cpu_embed_cache()
+        cache_tensor_spec = CpuCacheTensorSpec(
+            shm_key=self.args.multi_modal_cache_shm_id,
+            shape=(
+                self.embed_cache_tensor_meta.token_num,
+                self.embed_cache_tensor_meta.layer_num,
+                self.embed_cache_tensor_meta.hidden_size,
+            ),
+            dtype=self.embed_cache_tensor_meta.data_type,
+            size_bytes=self.embed_cache_tensor_meta.calcu_size(),
+        )
+        cache_tensor_creator = CpuCacheCreator(tensor_spec=cache_tensor_spec)
+        self.cpu_embed_cache_tensor, _ = cache_tensor_creator.create_or_attach(
+            init_shm_data=init_shm_data,
+            pin=not init_shm_data,
+            pin_no_blocking=False,
+        )
         return
 
     def alloc_indexes(self, token_num: int) -> Optional["MemoryBlock"]:
@@ -39,17 +50,14 @@ def release_indexes(self, block: "MemoryBlock"):
         return
 
     def copy_to_cache(self, embed_tensor: torch.Tensor, start_index_in_cache: int):
-        from .copy_to_cache import offload_embed_tensor_to_cache
-
         offload_embed_tensor_to_cache(
             embed_tensor=embed_tensor,
             cache_tensor=self.cpu_embed_cache_tensor,
             start_index_in_cache=start_index_in_cache,
         )
+        return
 
     def copy_vision_to_cache(self, embed_tensor: torch.Tensor, start_index_in_cache: int):
-        from .copy_to_cache import offload_embed_tensor_to_cache
-
         if embed_tensor.ndim == 3:
             # check for qwen3 vision embed tensor shape, use apply deepstack
             assert embed_tensor.shape[1] == self.cpu_embed_cache_tensor.shape[1]
@@ -59,123 +67,8 @@ def copy_vision_to_cache(self, embed_tensor: torch.Tensor, start_index_in_cache:
             cache_tensor=self.cpu_embed_cache_tensor,
             start_index_in_cache=start_index_in_cache,
         )
-
-    def _create_shm_embed_kv_cache(self):
-        shm_ptr = create_shm_kv_cache_ptr(
-            key=self.args.multi_modal_cache_shm_id, size=self.embed_cache_tensor_meta.calcu_size()
-        )
-        logger.info(f"create embed cache shm ptr: {shm_ptr}, size: {self.embed_cache_tensor_meta.calcu_size()}")
-        return
-
-    def _attach_shm_cpu_embed_cache(self):
-        shm_ptr = attach_shm_kv_cache_ptr(
-            key=self.args.multi_modal_cache_shm_id, size=self.embed_cache_tensor_meta.calcu_size()
-        )
-        handle = register_shm_ptr_to_pin(shm_ptr=shm_ptr, size=self.embed_cache_tensor_meta.calcu_size())
-        handle.wait()
-        numpy_array = np.frombuffer(
-            memoryview((ctypes.c_uint8 * self.embed_cache_tensor_meta.calcu_size()).from_address(shm_ptr)),
-            dtype=np.uint8,
-        )
-        shape = (
-            self.embed_cache_tensor_meta.token_num,
-            self.embed_cache_tensor_meta.layer_num,
-            self.embed_cache_tensor_meta.hidden_size,
-        )
-        self.cpu_embed_cache_tensor = (
-            torch.from_numpy(numpy_array).view(dtype=self.embed_cache_tensor_meta.data_type).view(shape)
-        )
-        assert shm_ptr == self.cpu_embed_cache_tensor.data_ptr()
-        return None
-
-
-class MemoryBlock:
-    """内存块类,表示一个连续的内存区域"""
-
-    def __init__(self, start, end):
-        self.start = start
-        self.end = end
-
-    def size(self):
-        return self.end - self.start
-
-    def __repr__(self):
-        return f"Block(start={self.start}, end={self.end})"
-
-    def can_merge(self, block: "MemoryBlock"):
-        return (self.start == block.end) or (block.start == self.end)
-
-
-class MemoryManager:
-    def __init__(self, total_size):
-        """
-        初始化内存管理器
-        :param total_size: 总内存大小
-        """
-        self.total_size = total_size
-        self.mem_set_by_start = SortedSet(key=lambda x: (x.start, x.size()))
-        self.mem_set_by_size = SortedSet(key=lambda x: (x.size(), x.start))
-        total = MemoryBlock(0, total_size)
-        self.__add(total)
-
-    def alloc(self, need_size: int) -> Optional[MemoryBlock]:
-        assert need_size > 0
-
-        if len(self.mem_set_by_size) == 0:
-            return None
-
-        key = MemoryBlock(start=-1, end=-1 + need_size)
-        find_index = self.mem_set_by_size.bisect_left(key)
-        if find_index < len(self.mem_set_by_size):
-            finded_mem_block: MemoryBlock = self.mem_set_by_size[find_index]
-            self.__del(finded_mem_block)
-            ret_mem_block = MemoryBlock(
-                start=finded_mem_block.start,
-                end=finded_mem_block.start + need_size,
-            )
-            left_block = MemoryBlock(
-                start=finded_mem_block.start + need_size,
-                end=finded_mem_block.end,
-            )
-            if left_block.size() > 0:
-                self.__add(left_block)
-
-            return ret_mem_block
-        else:
-            return None
-
-    def release(self, block: MemoryBlock):
-        if block is None:
-            return
-        if len(self.mem_set_by_size) == 0:
-            self.__add(block)
-            return
-
-        finded_index = self.mem_set_by_start.bisect_left(block)
-        for index in [finded_index - 1, finded_index, finded_index + 1]:
-            if index < len(self.mem_set_by_start):
-                sub_block: MemoryBlock = self.mem_set_by_start[index]
-                # merge
-                if block.can_merge(sub_block):
-                    self.__del(sub_block)
-                    merge_block = MemoryBlock(
-                        start=min(block.start, sub_block.start),
-                        end=max(block.end, sub_block.end),
-                    )
-                    self.release(merge_block)
-                    return
-        # 无法merge时,直接add
-        self.__add(block)
         return
 
-    def __add(self, block):
-        self.mem_set_by_start.add(block)
-        self.mem_set_by_size.add(block)
-
-    def __del(self, block):
-        self.mem_set_by_start.remove(block)
-        self.mem_set_by_size.remove(block)
-
 
 if __name__ == "__main__":
     mem = MemoryManager(total_size=2000)
diff --git a/lightllm/server/embed_cache/impl/naive_memory_cache.py b/lightllm/server/embed_cache/impl/naive_memory_cache.py
index fbce108762..5ad26fbcc8 100644
--- a/lightllm/server/embed_cache/impl/naive_memory_cache.py
+++ b/lightllm/server/embed_cache/impl/naive_memory_cache.py
@@ -10,7 +10,10 @@
 import multiprocessing.shared_memory as shm
 from ..utils import get_shm_name_data, free_shm
 from lightllm.utils.log_utils import init_logger
-from ..embed_cache_client import CpuEmbedCacheClient, MemoryBlock, SortedSet
+from sortedcontainers import SortedSet
+
+from ..allocator import MemoryBlock
+from ..embed_cache_client import CpuEmbedCacheClient
 
 logger = init_logger(__name__)
 
diff --git a/lightllm/server/multi_level_kv_cache/cpu_cache_client.py b/lightllm/server/multi_level_kv_cache/cpu_cache_client.py
index ae225255af..e4f37c0480 100644
--- a/lightllm/server/multi_level_kv_cache/cpu_cache_client.py
+++ b/lightllm/server/multi_level_kv_cache/cpu_cache_client.py
@@ -1,17 +1,11 @@
 import ctypes
-import torch
-import numpy as np
 from lightllm.utils.envs_utils import get_env_start_args, get_unique_server_name, get_disk_cache_prompt_limit_length
 from typing import List, Optional, Tuple
 from lightllm.utils.log_utils import init_logger
+from lightllm.common.cpu_cache import CpuCacheCreator, CpuCacheTensorSpec
 from .shm_objs import ShmDict, ShmLinkedList, _LinkedListItem, IntList
 from lightllm.server.core.objs import AtomicShmLock
-from lightllm.utils.kv_cache_utils import (
-    calcu_cpu_cache_meta,
-    create_shm_kv_cache_ptr,
-    attach_shm_kv_cache_ptr,
-    register_shm_ptr_to_pin,
-)
+from lightllm.utils.kv_cache_utils import calcu_cpu_cache_meta
 
 logger = init_logger(__name__)
 
@@ -30,11 +24,24 @@ def __init__(self, only_create_meta_data: bool, init_shm_data: bool):
         self._create_cpu_status_list(init_shm_data)
 
         if not only_create_meta_data:
-            if init_shm_data:
-                self._create_shm_cpu_kv_cache()
-                self.attach_shm_handle = None
-            else:
-                self.attach_shm_handle = self._attach_shm_cpu_kv_cache()
+            tensor_spec = CpuCacheTensorSpec(
+                shm_key=self.args.cpu_kv_cache_shm_id,
+                shape=(
+                    self.kv_cache_tensor_meta.page_num,
+                    self.kv_cache_tensor_meta.layer_num,
+                    self.kv_cache_tensor_meta.token_page_size,
+                    self.kv_cache_tensor_meta.num_heads,
+                    self.kv_cache_tensor_meta.get_merged_head_dim(),
+                ),
+                dtype=self.kv_cache_tensor_meta.data_type,
+                size_bytes=self.kv_cache_tensor_meta.calcu_size(),
+            )
+            tensor_creator = CpuCacheCreator(tensor_spec=tensor_spec)
+            self.cpu_kv_cache_tensor, self.attach_shm_handle = tensor_creator.create_or_attach(
+                init_shm_data=init_shm_data,
+                pin=not init_shm_data,
+                pin_no_blocking=True,
+            )
         return
 
     def get_one_empty_page(self, hash_key: int, disk_offload_enable: bool) -> Optional[int]:
@@ -274,51 +281,6 @@ def _create_cpu_status_list(self, init_shm_data: bool):
         )
         return
 
-    def _create_shm_cpu_kv_cache(self):
-        shm_ptr = create_shm_kv_cache_ptr(
-            key=self.args.cpu_kv_cache_shm_id, size=self.kv_cache_tensor_meta.calcu_size()
-        )
-        numpy_array = np.frombuffer(
-            memoryview((ctypes.c_uint8 * self.kv_cache_tensor_meta.calcu_size()).from_address(shm_ptr)), dtype=np.uint8
-        )
-        # 将 NumPy 数组转换为 PyTorch 张量
-        shape = (
-            self.kv_cache_tensor_meta.page_num,
-            self.kv_cache_tensor_meta.layer_num,
-            self.kv_cache_tensor_meta.token_page_size,
-            self.kv_cache_tensor_meta.num_heads,
-            self.kv_cache_tensor_meta.get_merged_head_dim(),
-        )
-        self.cpu_kv_cache_tensor = (
-            torch.from_numpy(numpy_array).view(dtype=self.kv_cache_tensor_meta.data_type).view(shape)
-        )
-        return
-
-    def _attach_shm_cpu_kv_cache(self):
-        shm_ptr = attach_shm_kv_cache_ptr(
-            key=self.args.cpu_kv_cache_shm_id, size=self.kv_cache_tensor_meta.calcu_size()
-        )
-        handle = register_shm_ptr_to_pin(shm_ptr=shm_ptr, size=self.kv_cache_tensor_meta.calcu_size())
-        numpy_array = np.frombuffer(
-            memoryview((ctypes.c_uint8 * self.kv_cache_tensor_meta.calcu_size()).from_address(shm_ptr)), dtype=np.uint8
-        )
-        shape = (
-            self.kv_cache_tensor_meta.page_num,
-            self.kv_cache_tensor_meta.layer_num,
-            self.kv_cache_tensor_meta.token_page_size,
-            self.kv_cache_tensor_meta.num_heads,
-            self.kv_cache_tensor_meta.get_merged_head_dim(),
-        )
-        self.cpu_kv_cache_tensor = (
-            torch.from_numpy(numpy_array).view(dtype=self.kv_cache_tensor_meta.data_type).view(shape)
-        )
-        assert shm_ptr == self.cpu_kv_cache_tensor.data_ptr()
-
-        # test code
-        # self.cpu_kv_cache_tensor = torch.zeros_like(self.cpu_kv_cache_tensor, device="cpu", pin_memory=True)
-        # self.cpu_kv_cache_tensor = torch.zeros_like(self.cpu_kv_cache_tensor, device="cuda")
-        return handle
-
 
 class _CpuPageStatus(_LinkedListItem):
     _pack_ = 4