diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index fee96acd9..5edf6ad20 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -61,6 +61,8 @@ def __init__(self, kvargs): self.finetune_config = kvargs.get("finetune_config", None) self.max_req_num = kvargs.get("max_req_num", 1000) self.max_seq_length = kvargs.get("max_seq_length", 1024 * 5) + # 用于等待外围的一些模块的初始化完成(如 CPU KV Cache 注册完成) + self.wait_events = kvargs.get("wait_events", []) # is_token_healing 和 return_all_prompt_logics 是有排斥关系的两个模式,只能单独有一个生效 # 主要是在prefill阶段返回多少个token的用于后续处理相关。 self.is_token_healing = kvargs.get("is_token_healing", False) @@ -110,12 +112,19 @@ def __init__(self, kvargs): self._init_inferstate_cls() self._autotune_warmup() self._init_padded_req() + # wait必须在init cudagraph 之前,避免错误捕获 + self._wait_other_modules_ready() self._init_cudagraph() self._check_max_len_infer() torch.cuda.empty_cache() set_model_init_status(True) return + def _wait_other_modules_ready(self): + for event in self.wait_events: + event.wait() + return + def _init_config(self): with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file: self.config = json.load(json_file) @@ -352,8 +361,13 @@ def _prefill( alloc_mem_index=infer_state.mem_index, max_q_seq_len=infer_state.max_q_seq_len, ) + prefill_mem_indexes_ready_event = torch.cuda.Event() + prefill_mem_indexes_ready_event.record() + infer_state.init_some_extra_state(self, model_input.input_ids) - return self._context_forward(model_input.input_ids, infer_state) + model_output = self._context_forward(model_input.input_ids, infer_state) + model_output.prefill_mem_indexes_ready_event = prefill_mem_indexes_ready_event + return model_output def _decode( self, @@ -505,6 +519,9 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod ) infer_state1.init_some_extra_state(self, input_ids1) + prefill_mem_indexes_ready_event = torch.cuda.Event() + prefill_mem_indexes_ready_event.record() + model_output0, model_output1 = self._overlap_tpsp_context_forward( input_ids0, infer_state0, input_ids1=input_ids1, infer_state1=infer_state1 ) @@ -512,6 +529,8 @@ def microbatch_overlap_prefill(self, model_input0: ModelInput, model_input1: Mod # 在开启使用deepep的时候,需要调用clear_deepep_buffer做资源清理,没有启用的时候 # 该调用没有实际意义 dist_group_manager.clear_deepep_buffer() + model_output0.prefill_mem_indexes_ready_event = prefill_mem_indexes_ready_event + model_output1.prefill_mem_indexes_ready_event = prefill_mem_indexes_ready_event return model_output0, model_output1 @torch.no_grad() diff --git a/lightllm/common/basemodel/batch_objs.py b/lightllm/common/basemodel/batch_objs.py index 7bf511380..9e6203b6e 100644 --- a/lightllm/common/basemodel/batch_objs.py +++ b/lightllm/common/basemodel/batch_objs.py @@ -58,6 +58,8 @@ def to_cuda(self): class ModelOutput: # 通用变量 logits: torch.Tensor + # 用于判断 mem_indexes 是否成功写入 req manager 中的事件对象。 + prefill_mem_indexes_ready_event: torch.Event = None # 专有变量,用于一些特殊的模型,特殊的模式下, 传递一些特殊 # 的输出变量。只在特殊的模型模式下才会具体使用和生效。 diff --git a/lightllm/common/basemodel/triton_kernel/kv_cache_offload.py b/lightllm/common/basemodel/triton_kernel/kv_cache_offload.py new file mode 100644 index 000000000..c6098f950 --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/kv_cache_offload.py @@ -0,0 +1,531 @@ +import torch + +import triton +import triton.language as tl + + +@triton.jit +def _offload_gpu_kv_to_cpu( + token_indexes_ptr, + gpu_kv_cache_ptr, + gpu_stride0, + gpu_stride1, + gpu_stride2, + 
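A minimal sketch of how a launcher might use the new `wait_events` hook introduced above, assuming a plain `threading.Event` (anything exposing a blocking `wait()` works); the launcher-side names below are illustrative and not part of this patch.

```python
import threading

# Illustrative only: an event that the CPU KV cache registration thread sets
# once registration/pinning of the shared-memory cache has finished.
cpu_cache_registered = threading.Event()

model_kvargs = {
    # ... the usual weight_dir / max_req_num / ... entries ...
    # _wait_other_modules_ready() calls .wait() on every entry of this list
    # before _init_cudagraph(), so graph capture cannot start while the
    # registration below is still in flight.
    "wait_events": [cpu_cache_registered],
}

def register_cpu_kv_cache():
    # ... attach and pin the shared-memory CPU KV cache here ...
    cpu_cache_registered.set()
```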
gpu_stride3, + cpu_kv_cache_ptr, + cpu_stride0, + cpu_stride1, + cpu_stride2, + cpu_stride3, + cpu_stride4, + page_indexes_ptr, + page_readies_ptr, + layer_num, + head_dim, + block_num, + cpu_k_start_head_index: tl.constexpr, + cpu_k_head_num: tl.constexpr, + gpu_k_start_head_index: tl.constexpr, + gpu_k_head_num: tl.constexpr, + cpu_v_start_head_index: tl.constexpr, + cpu_v_head_num: tl.constexpr, + gpu_v_start_head_index: tl.constexpr, + gpu_v_head_num: tl.constexpr, + BLOCK_HEAD_DIM: tl.constexpr, + TOKEN_BLOCK: tl.constexpr, +): + block_start_index = tl.program_id(0) + block_split_size = tl.num_programs(axis=0) + + for block_index in tl.range(block_start_index, block_num, block_split_size): + cpu_page_index = tl.load(page_indexes_ptr + block_index).to(tl.int64) + + ready_state = tl.load(page_readies_ptr + block_index) + + mask_layer_num = tl.where(cpu_page_index == -1, 0, 1) + mask_layer_num = tl.where(ready_state, 0, mask_layer_num) + + token_range = block_index * TOKEN_BLOCK + tl.arange(0, TOKEN_BLOCK) + token_indexes = tl.load(token_indexes_ptr + token_range).to(tl.int64) + head_dim_range = tl.arange(0, BLOCK_HEAD_DIM) + head_dim_mask = head_dim_range < head_dim + + for layer_index in range(layer_num * mask_layer_num): + for k_head_index in range(gpu_k_head_num): + gpu_k_head_index = k_head_index + gpu_k_start_head_index + cpu_k_head_index = k_head_index + cpu_k_start_head_index + + gpu_ptr = ( + gpu_kv_cache_ptr + + layer_index.to(tl.int64) * gpu_stride0 + + token_indexes[:, None] * gpu_stride1 + + gpu_k_head_index.to(tl.int64) * gpu_stride2 + + head_dim_range[None, :] + ) + gpu_data = tl.load(gpu_ptr, mask=head_dim_mask[None, :], other=0.0) + cpu_ptr = ( + cpu_kv_cache_ptr + + cpu_page_index * cpu_stride0 + + layer_index.to(tl.int64) * cpu_stride1 + + tl.arange(0, TOKEN_BLOCK)[:, None] * cpu_stride2 + + cpu_k_head_index * cpu_stride3 + + head_dim_range[None, :] + ) + tl.store( + cpu_ptr, + gpu_data, + mask=head_dim_mask[None, :], + cache_modifier=".wt", + ) + + for v_head_index in range(gpu_v_head_num): + gpu_v_head_index = v_head_index + gpu_v_start_head_index + cpu_v_head_index = v_head_index + cpu_v_start_head_index + + gpu_ptr = ( + gpu_kv_cache_ptr + + layer_index.to(tl.int64) * gpu_stride0 + + token_indexes[:, None] * gpu_stride1 + + gpu_v_head_index.to(tl.int64) * gpu_stride2 + + head_dim_range[None, :] + ) + gpu_data = tl.load(gpu_ptr, mask=head_dim_mask[None, :], other=0.0) + cpu_ptr = ( + cpu_kv_cache_ptr + + cpu_page_index * cpu_stride0 + + layer_index.to(tl.int64) * cpu_stride1 + + tl.arange(0, TOKEN_BLOCK)[:, None] * cpu_stride2 + + cpu_v_head_index * cpu_stride3 + + head_dim_range[None, :] + ) + tl.store( + cpu_ptr, + gpu_data, + mask=head_dim_mask[None, :], + cache_modifier=".wt", + ) + return + + +@torch.no_grad() +def offload_gpu_kv_to_cpu( + token_indexes: torch.Tensor, + gpu_kv_cache: torch.Tensor, + cpu_kv_cache: torch.Tensor, + page_indexes: torch.Tensor, + page_readies: torch.Tensor, + tp_index: int, + tp_world_size: int, + grid_num: int, + _cache_data={}, +): + """ + this function is used to offload GPU KV cache to CPU KV cache. 
+ Args: + token_indexes: (token_num,) + gpu_kv_cache: (layer_num, token_num, head_num, head_dim) + cpu_kv_cache: (all_page_num, layer_num, token_block_size, head_num, head_dim) + page_indexes: (page_num,) + page_readies: (page_num,) + """ + token_block_size = cpu_kv_cache.shape[2] + token_num = token_indexes.shape[0] + assert token_num == page_indexes.shape[0] * token_block_size + assert page_indexes.shape == page_readies.shape + + gpu_heads = gpu_kv_cache.shape[2] + gpu_head_dim = gpu_kv_cache.shape[3] + cpu_heads = cpu_kv_cache.shape[3] + cpu_head_dim = cpu_kv_cache.shape[4] + assert gpu_head_dim == cpu_head_dim + assert gpu_kv_cache.shape[0] == cpu_kv_cache.shape[1] + head_dim = gpu_head_dim + scale_size = (tp_world_size * gpu_heads) // cpu_heads + + # 计算需要拷贝的 head 索引的对应关系 + if (gpu_heads, cpu_heads, tp_index, tp_world_size) in _cache_data: + need_offload, head_info_tuple = _cache_data[(gpu_heads, cpu_heads, tp_index, tp_world_size)] + else: + if cpu_heads > 1: + assert (tp_world_size * gpu_heads) % cpu_heads == 0 + assert cpu_heads % 2 == 0 + + cpu_heads_index = ( + torch.arange(0, cpu_heads, device="cpu", dtype=torch.int32) + .view(cpu_heads, 1) + .tile((1, scale_size)) + .view(2, tp_world_size, -1) + ) + # k + k_cpu_heads_index = cpu_heads_index[0][tp_index] + # v + v_cpu_heads_index = cpu_heads_index[1][tp_index] + + cpu_heads_index = torch.cat([k_cpu_heads_index, v_cpu_heads_index], dim=0).view(2, -1).numpy() + gpu_heads_index = torch.arange(0, gpu_heads, device="cpu", dtype=torch.int32).view(2, -1).numpy() + + need_offload = tp_index % scale_size == 0 + + cpu_k_start_head_index = int(cpu_heads_index[0, 0]) + cpu_k_head_num = len(cpu_heads_index[0]) + gpu_k_start_head_index = int(gpu_heads_index[0, 0]) + gpu_k_head_num = len(gpu_heads_index[0]) + assert cpu_k_head_num == gpu_k_head_num + cpu_v_start_head_index = int(cpu_heads_index[1, 0]) + cpu_v_head_num = len(cpu_heads_index[1]) + gpu_v_start_head_index = int(gpu_heads_index[1, 0]) + gpu_v_head_num = len(gpu_heads_index[1]) + assert cpu_v_head_num == gpu_v_head_num + + head_info_tuple = ( + cpu_k_start_head_index, + cpu_k_head_num, + gpu_k_start_head_index, + gpu_k_head_num, + cpu_v_start_head_index, + cpu_v_head_num, + gpu_v_start_head_index, + gpu_v_head_num, + ) + + else: + assert gpu_heads == 1 + assert cpu_heads == 1 + + need_offload = tp_index == 0 + + cpu_k_start_head_index = 0 + cpu_k_head_num = 1 + gpu_k_start_head_index = 0 + gpu_k_head_num = 1 + cpu_v_start_head_index = 0 + cpu_v_head_num = 0 + gpu_v_start_head_index = 0 + gpu_v_head_num = 0 + head_info_tuple = ( + cpu_k_start_head_index, + cpu_k_head_num, + gpu_k_start_head_index, + gpu_k_head_num, + cpu_v_start_head_index, + cpu_v_head_num, + gpu_v_start_head_index, + gpu_v_head_num, + ) + + _cache_data[(gpu_heads, cpu_heads, tp_index, tp_world_size)] = (need_offload, head_info_tuple) + + ( + cpu_k_start_head_index, + cpu_k_head_num, + gpu_k_start_head_index, + gpu_k_head_num, + cpu_v_start_head_index, + cpu_v_head_num, + gpu_v_start_head_index, + gpu_v_head_num, + ) = head_info_tuple + + if not need_offload: + return + + assert token_block_size == triton.next_power_of_2(token_block_size) + page_num = page_indexes.shape[0] + + grid = (grid_num,) + num_warps = 4 + + _offload_gpu_kv_to_cpu[grid]( + token_indexes_ptr=token_indexes, + gpu_kv_cache_ptr=gpu_kv_cache, + gpu_stride0=gpu_kv_cache.stride(0), + gpu_stride1=gpu_kv_cache.stride(1), + gpu_stride2=gpu_kv_cache.stride(2), + gpu_stride3=gpu_kv_cache.stride(3), + cpu_kv_cache_ptr=cpu_kv_cache, + 
cpu_stride0=cpu_kv_cache.stride(0), + cpu_stride1=cpu_kv_cache.stride(1), + cpu_stride2=cpu_kv_cache.stride(2), + cpu_stride3=cpu_kv_cache.stride(3), + cpu_stride4=cpu_kv_cache.stride(4), + page_indexes_ptr=page_indexes, + page_readies_ptr=page_readies, + layer_num=gpu_kv_cache.shape[0], + head_dim=head_dim, + block_num=page_num, + cpu_k_start_head_index=cpu_k_start_head_index, + cpu_k_head_num=cpu_k_head_num, + gpu_k_start_head_index=gpu_k_start_head_index, + gpu_k_head_num=gpu_k_head_num, + cpu_v_start_head_index=cpu_v_start_head_index, + cpu_v_head_num=cpu_v_head_num, + gpu_v_start_head_index=gpu_v_start_head_index, + gpu_v_head_num=gpu_v_head_num, + BLOCK_HEAD_DIM=triton.next_power_of_2(head_dim), + TOKEN_BLOCK=token_block_size, + num_warps=num_warps, + num_stages=1, + ) + return + + +@triton.jit +def _load_cpu_cache_to_gpu( + gpu_mem_indexes_ptr, + copy_token_num, + copy_block_num, + cpu_mem_indexes_ptr, + cpu_page_indexes_ptr, + gpu_kv_cache_ptr, + gpu_stride0, + gpu_stride1, + gpu_stride2, + gpu_stride3, + cpu_kv_cache_ptr, + cpu_stride0, + cpu_stride1, + cpu_stride2, + cpu_stride3, + cpu_stride4, + layer_num, + head_dim, + cpu_k_start_head_index: tl.constexpr, + cpu_k_head_num: tl.constexpr, + gpu_k_start_head_index: tl.constexpr, + gpu_k_head_num: tl.constexpr, + cpu_v_start_head_index: tl.constexpr, + cpu_v_head_num: tl.constexpr, + gpu_v_start_head_index: tl.constexpr, + gpu_v_head_num: tl.constexpr, + BLOCK_HEAD_DIM: tl.constexpr, + TOKEN_BLOCK: tl.constexpr, +): + block_index_start = tl.program_id(0) + split_block_num = tl.num_programs(0) + for block_index in range(block_index_start, copy_block_num, split_block_num): + token_range = block_index * TOKEN_BLOCK + tl.arange(0, TOKEN_BLOCK) + token_mask = token_range < copy_token_num + gpu_mem_indexes = tl.load(gpu_mem_indexes_ptr + token_range, mask=token_mask).to(tl.int64) + cpu_mem_indexes = tl.load(cpu_mem_indexes_ptr + token_range, mask=token_mask).to(tl.int64) + cpu_page_indexes = tl.load(cpu_page_indexes_ptr + token_range, mask=token_mask).to(tl.int64) + + head_dim_range = tl.arange(0, BLOCK_HEAD_DIM) + head_dim_mask = head_dim_range < head_dim + + for layer_index in range(layer_num): + move_mask = token_mask[:, None] & head_dim_mask[None, :] + + for k_head_index in range(cpu_k_head_num): + gpu_k_head_index = k_head_index + gpu_k_start_head_index + cpu_k_head_index = k_head_index + cpu_k_start_head_index + + cpu_ptr = ( + cpu_kv_cache_ptr + + cpu_page_indexes[:, None] * cpu_stride0 + + layer_index.to(tl.int64) * cpu_stride1 + + cpu_mem_indexes[:, None] * cpu_stride2 + + cpu_k_head_index * cpu_stride3 + + head_dim_range[None, :] + ) + cpu_data = tl.load(cpu_ptr, mask=move_mask, other=0.0) + + gpu_ptr = ( + gpu_kv_cache_ptr + + layer_index.to(tl.int64) * gpu_stride0 + + gpu_mem_indexes[:, None] * gpu_stride1 + + gpu_k_head_index * gpu_stride2 + + head_dim_range[None, :] + ) + + tl.store( + gpu_ptr, + cpu_data, + mask=move_mask, + ) + + for v_head_index in range(cpu_v_head_num): + gpu_v_head_index = v_head_index + gpu_v_start_head_index + cpu_v_head_index = v_head_index + cpu_v_start_head_index + + cpu_ptr = ( + cpu_kv_cache_ptr + + cpu_page_indexes[:, None] * cpu_stride0 + + layer_index.to(tl.int64) * cpu_stride1 + + cpu_mem_indexes[:, None] * cpu_stride2 + + cpu_v_head_index * cpu_stride3 + + head_dim_range[None, :] + ) + cpu_data = tl.load(cpu_ptr, mask=move_mask, other=0.0) + + gpu_ptr = ( + gpu_kv_cache_ptr + + layer_index.to(tl.int64) * gpu_stride0 + + gpu_mem_indexes[:, None] * gpu_stride1 + + gpu_v_head_index * 
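Stepping back to the `offload_gpu_kv_to_cpu` entry point above: it only depends on the shape and alignment contracts spelled out in its docstring, so it can be exercised with plain tensors. A minimal smoke-test sketch; all sizes, dtypes, and the pinned host allocation (standing in for the registered shared-memory CPU cache) are assumptions chosen for illustration.

```python
import torch
from lightllm.common.basemodel.triton_kernel.kv_cache_offload import offload_gpu_kv_to_cpu

layer_num, head_dim = 2, 128
gpu_heads = 16            # 8 K heads + 8 V heads held by this rank
cpu_heads = 16            # K half followed by V half, matching the index mapping above
token_block_size = 64     # must be a power of two (asserted by the wrapper)
page_num = 2
token_num = page_num * token_block_size

gpu_kv_cache = torch.randn(layer_num, 4096, gpu_heads, head_dim,
                           dtype=torch.bfloat16, device="cuda")
# Stand-in for the shared-memory CPU cache; in the real flow this tensor is
# built from a shm pointer and registered as pinned host memory.
cpu_kv_cache = torch.zeros(8, layer_num, token_block_size, cpu_heads, head_dim,
                           dtype=torch.bfloat16, pin_memory=True)

token_indexes = torch.arange(token_num, dtype=torch.int32, device="cuda")
page_indexes = torch.tensor([3, 5], dtype=torch.int32, device="cuda")  # destination CPU pages
# dtype of the ready flags is an assumption; any 0/1 flag tensor would do.
page_readies = torch.zeros(page_num, dtype=torch.bool, device="cuda")

# In the real prefill flow the caller would first synchronize on
# model_output.prefill_mem_indexes_ready_event so token_indexes are valid.
offload_gpu_kv_to_cpu(
    token_indexes=token_indexes,
    gpu_kv_cache=gpu_kv_cache,
    cpu_kv_cache=cpu_kv_cache,
    page_indexes=page_indexes,
    page_readies=page_readies,
    tp_index=0,
    tp_world_size=1,
    grid_num=page_num,
)
torch.cuda.synchronize()
```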
gpu_stride2 + + head_dim_range[None, :] + ) + + tl.store( + gpu_ptr, + cpu_data, + mask=move_mask, + ) + return + + +@torch.no_grad() +def load_cpu_kv_to_gpu( + gpu_mem_indexes: torch.Tensor, + gpu_kv_cache: torch.Tensor, + cpu_kv_cache: torch.Tensor, + page_indexes: torch.Tensor, + tp_index: int, + tp_world_size: int, + grid_num: int, + _cache_data={}, +): + """ + this function is used to offload GPU KV cache to CPU KV cache. + Args: + gpu_mem_indexes: (token_num,) + gpu_kv_cache: (layer_num, all_token_num, head_num, head_dim) + cpu_kv_cache: (all_page_num, layer_num, token_block_size, head_num, head_dim) + page_indexes: (page_num,) + """ + token_block_size = cpu_kv_cache.shape[2] + cpu_page_num = page_indexes.shape[0] + cpu_page_all_token_num = cpu_page_num * token_block_size + assert gpu_mem_indexes.shape[0] <= cpu_page_all_token_num + move_token_num = gpu_mem_indexes.shape[0] + + cpu_page_indexes = page_indexes.view((cpu_page_num, 1)).tile((1, token_block_size)).view(-1) + cpu_mem_indexes = torch.arange(0, cpu_page_all_token_num, device="cuda", dtype=torch.int32) % token_block_size + cpu_page_indexes = cpu_page_indexes[-move_token_num:] + cpu_mem_indexes = cpu_mem_indexes[-move_token_num:] + + gpu_heads = gpu_kv_cache.shape[2] + gpu_head_dim = gpu_kv_cache.shape[3] + cpu_heads = cpu_kv_cache.shape[3] + cpu_head_dim = cpu_kv_cache.shape[4] + assert gpu_head_dim == cpu_head_dim + head_dim = gpu_head_dim + scale_size = (tp_world_size * gpu_heads) // cpu_heads + + # 计算需要拷贝的 head 索引的对应关系 + if (gpu_heads, cpu_heads, tp_index, tp_world_size) in _cache_data: + head_info_tuple = _cache_data[(gpu_heads, cpu_heads, tp_index, tp_world_size)] + else: + if cpu_heads > 1: + assert (tp_world_size * gpu_heads) % cpu_heads == 0 + assert cpu_heads % 2 == 0 + + cpu_heads_index = ( + torch.arange(0, cpu_heads, device="cpu", dtype=torch.int32) + .view(cpu_heads, 1) + .tile((1, scale_size)) + .view(2, tp_world_size, -1) + ) + # k + k_cpu_heads_index = cpu_heads_index[0][tp_index] + # v + v_cpu_heads_index = cpu_heads_index[1][tp_index] + + cpu_heads_index = torch.cat([k_cpu_heads_index, v_cpu_heads_index], dim=0).view(2, -1).numpy() + gpu_heads_index = torch.arange(0, gpu_heads, device="cpu", dtype=torch.int32).view(2, -1).numpy() + + cpu_k_start_head_index = int(cpu_heads_index[0, 0]) + cpu_k_head_num = len(cpu_heads_index[0]) + gpu_k_start_head_index = int(gpu_heads_index[0, 0]) + gpu_k_head_num = len(gpu_heads_index[0]) + assert cpu_k_head_num == gpu_k_head_num + cpu_v_start_head_index = int(cpu_heads_index[1, 0]) + cpu_v_head_num = len(cpu_heads_index[1]) + gpu_v_start_head_index = int(gpu_heads_index[1, 0]) + gpu_v_head_num = len(gpu_heads_index[1]) + assert cpu_v_head_num == gpu_v_head_num + + head_info_tuple = ( + cpu_k_start_head_index, + cpu_k_head_num, + gpu_k_start_head_index, + gpu_k_head_num, + cpu_v_start_head_index, + cpu_v_head_num, + gpu_v_start_head_index, + gpu_v_head_num, + ) + + else: + assert gpu_heads == 1 + assert cpu_heads == 1 + + cpu_k_start_head_index = 0 + cpu_k_head_num = 1 + gpu_k_start_head_index = 0 + gpu_k_head_num = 1 + cpu_v_start_head_index = 0 + cpu_v_head_num = 0 + gpu_v_start_head_index = 0 + gpu_v_head_num = 0 + head_info_tuple = ( + cpu_k_start_head_index, + cpu_k_head_num, + gpu_k_start_head_index, + gpu_k_head_num, + cpu_v_start_head_index, + cpu_v_head_num, + gpu_v_start_head_index, + gpu_v_head_num, + ) + + _cache_data[(gpu_heads, cpu_heads, tp_index, tp_world_size)] = head_info_tuple + + ( + cpu_k_start_head_index, + cpu_k_head_num, + gpu_k_start_head_index, + 
gpu_k_head_num, + cpu_v_start_head_index, + cpu_v_head_num, + gpu_v_start_head_index, + gpu_v_head_num, + ) = head_info_tuple + + TOKEN_BLOCK = 128 + + grid = (grid_num,) + num_warps = 4 + + _load_cpu_cache_to_gpu[grid]( + gpu_mem_indexes_ptr=gpu_mem_indexes, + copy_token_num=move_token_num, + copy_block_num=triton.cdiv(move_token_num, TOKEN_BLOCK), + cpu_mem_indexes_ptr=cpu_mem_indexes, + cpu_page_indexes_ptr=cpu_page_indexes, + gpu_kv_cache_ptr=gpu_kv_cache, + gpu_stride0=gpu_kv_cache.stride(0), + gpu_stride1=gpu_kv_cache.stride(1), + gpu_stride2=gpu_kv_cache.stride(2), + gpu_stride3=gpu_kv_cache.stride(3), + cpu_kv_cache_ptr=cpu_kv_cache, + cpu_stride0=cpu_kv_cache.stride(0), + cpu_stride1=cpu_kv_cache.stride(1), + cpu_stride2=cpu_kv_cache.stride(2), + cpu_stride3=cpu_kv_cache.stride(3), + cpu_stride4=cpu_kv_cache.stride(4), + layer_num=gpu_kv_cache.shape[0], + head_dim=head_dim, + cpu_k_start_head_index=cpu_k_start_head_index, + cpu_k_head_num=cpu_k_head_num, + gpu_k_start_head_index=gpu_k_start_head_index, + gpu_k_head_num=gpu_k_head_num, + cpu_v_start_head_index=cpu_v_start_head_index, + cpu_v_head_num=cpu_v_head_num, + gpu_v_start_head_index=gpu_v_start_head_index, + gpu_v_head_num=gpu_v_head_num, + BLOCK_HEAD_DIM=triton.next_power_of_2(head_dim), + TOKEN_BLOCK=TOKEN_BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return diff --git a/lightllm/distributed/custom_all_gather.py b/lightllm/distributed/custom_all_gather.py index e00a2da22..44c72fcda 100644 --- a/lightllm/distributed/custom_all_gather.py +++ b/lightllm/distributed/custom_all_gather.py @@ -28,7 +28,6 @@ from lightllm.utils.log_utils import init_logger from lightllm.utils.device_utils import has_nvlink from lightllm.utils.light_utils import light_ops -from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager try: diff --git a/lightllm/distributed/custom_all_reduce.py b/lightllm/distributed/custom_all_reduce.py index bdcd9b6e8..695d0ca08 100644 --- a/lightllm/distributed/custom_all_reduce.py +++ b/lightllm/distributed/custom_all_reduce.py @@ -29,7 +29,6 @@ from lightllm.utils.device_utils import has_nvlink from lightllm.utils.sgl_utils import sgl_allreduce_ops from lightllm.utils.vllm_utils import vllm_ops -from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager logger = init_logger(__name__) @@ -225,6 +224,9 @@ def all_reduce(self, inp: torch.Tensor, *, out: torch.Tensor = None, registered: buffer. """ if out is None: + # fix circle import + from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager + out = g_cache_manager.alloc_tensor(inp.shape, inp.dtype, device=inp.device, is_graph_out=False) if registered: ops.all_reduce(self._ptr, inp, out, 0, 0) @@ -243,6 +245,9 @@ def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]: else: # If warm up, mimic the allocation pattern since custom # allreduce is out-of-place. 
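And the reverse direction, continuing the toy tensors from the offload sketch above: `load_cpu_kv_to_gpu` fills the requested GPU slots from the tail-aligned CPU pages, so a round-trip equality check works when a single TP rank owns every head. This usage is illustrative only; the real call sites are not part of this diff.

```python
from lightllm.common.basemodel.triton_kernel.kv_cache_offload import load_cpu_kv_to_gpu

# Pull the two CPU pages written by the offload sketch back into fresh GPU slots.
gpu_dst_slots = torch.arange(1000, 1000 + token_num, dtype=torch.int32, device="cuda")

load_cpu_kv_to_gpu(
    gpu_mem_indexes=gpu_dst_slots,
    gpu_kv_cache=gpu_kv_cache,
    cpu_kv_cache=cpu_kv_cache,
    page_indexes=page_indexes,
    tp_index=0,
    tp_world_size=1,
    grid_num=page_num,
)
torch.cuda.synchronize()

# Token i was offloaded from GPU row i, so the reloaded rows should match exactly.
assert torch.equal(gpu_kv_cache[:, 1000:1000 + token_num], gpu_kv_cache[:, :token_num])
```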
+ # fix circle import + from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager + out = g_cache_manager.alloc_tensor(input.shape, input.dtype, device=input.device, is_graph_out=False) return out else: diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 8fa519578..df4641a65 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -519,4 +519,25 @@ def make_argument_parser() -> argparse.ArgumentParser: default=0.03, help="""The interval of the schedule time, default is 30ms.""", ) + parser.add_argument( + "--enable_cpu_cache", + action="store_true", + help="""enable cpu cache to store kv cache. prefer to use hugepages for better performance.""", + ) + parser.add_argument( + "--cpu_cache_storage_size", + type=float, + default=2, + help="""The capacity of cpu cache. GB used.""", + ) + parser.add_argument( + "--cpu_cache_token_page_size", + type=int, + default=256, + help="""The token page size of cpu cache""", + ) + parser.add_argument("--enable_disk_cache", action="store_true", help="""enable disk cache to store kv cache.""") + parser.add_argument( + "--disk_cache_storage_size", type=float, default=10, help="""The capacity of disk cache. GB used.""" + ) return parser diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py index 20548509f..8bda50fb7 100755 --- a/lightllm/server/api_http.py +++ b/lightllm/server/api_http.py @@ -39,6 +39,7 @@ from fastapi import BackgroundTasks, FastAPI, Request, WebSocket, WebSocketDisconnect from fastapi.responses import Response, StreamingResponse, JSONResponse from lightllm.server.core.objs.sampling_params import SamplingParams +from lightllm.server.core.objs import StartArgs from .multimodal_params import MultimodalParams from .httpserver.manager import HttpServerManager from .httpserver_for_pd_master.manager import HttpServerManagerForPDMaster @@ -49,7 +50,6 @@ from lightllm.server.metrics.manager import MetricClient from lightllm.utils.envs_utils import get_unique_server_name from dataclasses import dataclass -from lightllm.server.core.objs.start_args_type import StartArgs from .api_openai import chat_completions_impl, completions_impl from .api_models import ( @@ -73,7 +73,7 @@ class G_Objs: httpserver_manager: Union[HttpServerManager, HttpServerManagerForPDMaster] = None shared_token_load: TokenLoad = None - def set_args(self, args): + def set_args(self, args: StartArgs): self.args = args from .api_lightllm import lightllm_generate, lightllm_generate_stream from .api_tgi import tgi_generate_impl, tgi_generate_stream_impl @@ -90,22 +90,13 @@ def set_args(self, args): if args.run_mode == "pd_master": self.metric_client = MetricClient(args.metric_port) self.httpserver_manager = HttpServerManagerForPDMaster( - args, - metric_port=args.metric_port, + args=args, ) else: init_tokenizer(args) # for openai api SamplingParams.load_generation_cfg(args.model_dir) self.metric_client = MetricClient(args.metric_port) - self.httpserver_manager = HttpServerManager( - args, - router_port=args.router_port, - cache_port=args.cache_port, - detokenization_pub_port=args.detokenization_pub_port, - visual_port=args.visual_port, - enable_multimodal=args.enable_multimodal, - metric_port=args.metric_port, - ) + self.httpserver_manager = HttpServerManager(args=args) dp_size_in_node = max(1, args.dp // args.nnodes) # 兼容多机纯tp的运行模式,这时候 1 // 2 == 0, 需要兼容 self.shared_token_load = TokenLoad(f"{get_unique_server_name()}_shared_token_load", dp_size_in_node) diff --git a/lightllm/server/api_start.py 
b/lightllm/server/api_start.py index 5bd61666e..8557be579 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -74,6 +74,10 @@ def normal_or_p_d_start(args): if args.run_mode not in ["normal", "prefill", "decode", "nixl_prefill", "nixl_decode"]: return + if args.enable_cpu_cache: + # 生成一个用于创建cpu kv cache的共享内存id。 + args.cpu_kv_cache_shm_id = uuid.uuid1().int % 123456789 + assert args.zmq_mode in ["tcp://", "ipc:///tmp/"] # 确保单机上多实列不冲突 if args.zmq_mode == "ipc:///tmp/": @@ -214,19 +218,20 @@ def normal_or_p_d_start(args): node_world_size = args.tp // args.nnodes can_use_ports = alloc_can_use_network_port( - num=7 + node_world_size + args.visual_dp * args.visual_tp, used_nccl_ports=already_uesd_ports + num=8 + node_world_size + args.visual_dp * args.visual_tp, used_nccl_ports=already_uesd_ports ) logger.info(f"alloced ports: {can_use_ports}") ( router_port, detokenization_port, - detokenization_pub_port, + http_server_port, visual_port, audio_port, cache_port, metric_port, - ) = can_use_ports[0:7] - can_use_ports = can_use_ports[7:] + multi_level_kv_cache_port, + ) = can_use_ports[0:8] + can_use_ports = can_use_ports[8:] visual_model_tp_ports = [] for _ in range(args.visual_dp): @@ -237,11 +242,12 @@ def normal_or_p_d_start(args): # 将申请好的端口放入args参数中 args.router_port = router_port args.detokenization_port = detokenization_port - args.detokenization_pub_port = detokenization_pub_port + args.http_server_port = http_server_port args.visual_port = visual_port args.audio_port = audio_port args.cache_port = cache_port args.metric_port = metric_port + args.multi_level_kv_cache_port = multi_level_kv_cache_port # 申请在 p d 分离模式下,会用的端口 args.pd_node_infer_rpyc_ports = can_use_ports[0:node_world_size] @@ -268,50 +274,51 @@ def normal_or_p_d_start(args): start_funcs=[ start_cache_manager, ], - start_args=[(cache_port, args)], + start_args=[(args,)], ) + process_manager.start_submodule_processes( + start_funcs=[ + start_visual_process, + ], + start_args=[ + (args, visual_model_tp_ports), + ], + ) + if args.enable_multimodal_audio: from .audioserver.manager import start_audio_process - process_manager.start_submodule_processes( - start_funcs=[ - start_visual_process, - ], - start_args=[ - (args, audio_port, visual_port, cache_port, visual_model_tp_ports), - ], - ) process_manager.start_submodule_processes( start_funcs=[ start_audio_process, ], start_args=[ - (args, router_port, audio_port, cache_port), + (args,), ], ) - else: - process_manager.start_submodule_processes( - start_funcs=[ - start_visual_process, - ], - start_args=[ - (args, router_port, visual_port, cache_port, visual_model_tp_ports), - ], - ) + if args.enable_cpu_cache: + from .multi_level_kv_cache.manager import start_multi_level_kv_cache_manager + + process_manager.start_submodule_processes( + start_funcs=[ + start_multi_level_kv_cache_manager, + ], + start_args=[(args,)], + ) process_manager.start_submodule_processes( start_funcs=[ start_metric_manager, ], - start_args=[(metric_port, args)], + start_args=[(args,)], ) process_manager.start_submodule_processes( start_funcs=[start_router_process, start_detokenization_process], start_args=[ - (args, router_port, detokenization_port, metric_port), - (args, detokenization_port, detokenization_pub_port), + (args,), + (args,), ], ) @@ -381,7 +388,7 @@ def pd_master_start(args): start_funcs=[ start_metric_manager, ], - start_args=[(metric_port, args)], + start_args=[(args,)], ) command = [ diff --git a/lightllm/server/audioserver/manager.py 
b/lightllm/server/audioserver/manager.py index 709ea5ca2..451e22c90 100644 --- a/lightllm/server/audioserver/manager.py +++ b/lightllm/server/audioserver/manager.py @@ -11,7 +11,7 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) from lightllm.utils.log_utils import init_logger from lightllm.server.core.objs.io_objs.group_req import GroupReqIndexes -from lightllm.server.core.objs.shm_req_manager import ShmReqManager +from lightllm.server.core.objs.shm_req_manager import ShmReqManager, StartArgs from lightllm.server.multimodal_params import AudioItem from .model_infer.model_rpc import start_model_process, AudioModelRpcClient from lightllm.utils.graceful_utils import graceful_registry @@ -24,20 +24,22 @@ class AudioManager: def __init__( self, - args, - router_port, - audio_port, - cache_port, + args: StartArgs, infer_batch_size=4, ): context = zmq.asyncio.Context(2) - self.send_to_router = context.socket(zmq.PUSH) - self.send_to_router.connect(f"{args.zmq_mode}127.0.0.1:{router_port}") - self.recv_from_visualserver = context.socket(zmq.PULL) - self.recv_from_visualserver.bind(f"{args.zmq_mode}127.0.0.1:{audio_port}") - self.cache_client = rpyc.connect("localhost", cache_port, config={"allow_pickle": True}) - self.cache_port = cache_port + if args.enable_cpu_cache: + self.send_to_next_module = context.socket(zmq.PUSH) + self.send_to_next_module.connect(f"{args.zmq_mode}127.0.0.1:{args.multi_level_kv_cache_port}") + else: + self.send_to_next_module = context.socket(zmq.PUSH) + self.send_to_next_module.connect(f"{args.zmq_mode}127.0.0.1:{args.router_port}") + + self.zmq_recv_socket = context.socket(zmq.PULL) + self.zmq_recv_socket.bind(f"{args.zmq_mode}127.0.0.1:{args.audio_port}") + self.cache_client = rpyc.connect("localhost", args.cache_port, config={"allow_pickle": True}) + self.cache_port = args.cache_port self.waiting_reqs: List[GroupReqIndexes] = [] self.model_weightdir = args.model_dir self.tp_world_size = args.tp @@ -93,7 +95,7 @@ async def loop_for_fwd(self): # 因为连接断开 aborted 掉的请求也需要传输到后续的模块进行处理 # 因为采用 shm 来映射所有的 req 对象以后,引用管理情况复杂了 # 需要一些一致的流程来保证不出现异步问题。 - self.send_to_router.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) + self.send_to_next_module.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) continue multimodal_params = group_req_indexes.multimodal_params @@ -113,24 +115,26 @@ async def loop_for_fwd(self): await self.infer_audios(audios_need_infer) audios_need_infer = [] for _group_req_indexes in processing_group_reqs: - self.send_to_router.send_pyobj(_group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) + self.send_to_next_module.send_pyobj( + _group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL + ) processing_group_reqs = [] if len(audios_need_infer) == 0: - self.send_to_router.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) + self.send_to_next_module.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) else: processing_group_reqs.append(group_req_indexes) if len(audios_need_infer) > 0: await self.infer_audios(audios_need_infer) for _group_req_indexes in processing_group_reqs: - self.send_to_router.send_pyobj(_group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) + self.send_to_next_module.send_pyobj(_group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) processing_group_reqs = [] audios_need_infer = [] async def loop_for_netio_req(self): while True: - recv_req: GroupReqIndexes = await self.recv_from_visualserver.recv_pyobj() + recv_req: GroupReqIndexes = await self.zmq_recv_socket.recv_pyobj() if 
isinstance(recv_req, GroupReqIndexes): self.waiting_reqs.append(recv_req) else: @@ -144,13 +148,13 @@ def clean_up(self): return -def start_audio_process(args, router_port, audio_port, cache_port, pipe_writer): +def start_audio_process(args, pipe_writer): # 注册graceful 退出的处理 graceful_registry(inspect.currentframe().f_code.co_name) setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::audio_server") try: - audioserver = AudioManager(args, router_port, audio_port, cache_port) + audioserver = AudioManager(args=args) asyncio.run(audioserver.wait_to_model_ready()) except Exception as e: logger.exception(str(e)) diff --git a/lightllm/server/core/objs/__init__.py b/lightllm/server/core/objs/__init__.py index 5594df6a0..165e565aa 100644 --- a/lightllm/server/core/objs/__init__.py +++ b/lightllm/server/core/objs/__init__.py @@ -3,3 +3,4 @@ from .shm_req_manager import ShmReqManager from .rpc_shm import RpcShmParams, RpcShmResults, ShmSyncStatusArray from .start_args_type import StartArgs +from .atomic_lock import AtomicShmLock diff --git a/lightllm/server/core/objs/atomic_lock.py b/lightllm/server/core/objs/atomic_lock.py index 7066e206d..cc5725473 100644 --- a/lightllm/server/core/objs/atomic_lock.py +++ b/lightllm/server/core/objs/atomic_lock.py @@ -1,4 +1,5 @@ import atomics +import time from multiprocessing import shared_memory from lightllm.utils.log_utils import init_logger from lightllm.utils.shm_utils import create_or_link_shm @@ -25,3 +26,17 @@ def __exit__(self, exc_type, exc_val, exc_tb): while not a.cmpxchg_weak(1, 0): pass return False + + # acquire_sleep1ms 和 release 是某些特定场景下主动使用进行锁获取的操作函数 + def acquire_sleep1ms(self): + with atomics.atomicview(buffer=self.shm.buf, atype=atomics.INT) as a: + while not a.cmpxchg_weak(0, 1): + logger.warning("acquire_sleep1ms wait for 1ms") + time.sleep(0.001) + pass + + def release(self): + with atomics.atomicview(buffer=self.shm.buf, atype=atomics.INT) as a: + while not a.cmpxchg_weak(1, 0): + pass + return diff --git a/lightllm/server/core/objs/req.py b/lightllm/server/core/objs/req.py index f04bb7ba2..f7fbfb974 100644 --- a/lightllm/server/core/objs/req.py +++ b/lightllm/server/core/objs/req.py @@ -5,9 +5,11 @@ from .sampling_params import SamplingParams from .out_token_circlequeue import CircularQueue from .shm_array import ShmArray +from .token_chunck_hash_list import TokenHashList, CpuCachePageList from lightllm.server.req_id_generator import convert_sub_id_to_group_id from lightllm.utils.envs_utils import get_unique_server_name from lightllm.utils.envs_utils import get_env_start_args +from lightllm.utils.kv_cache_utils import compute_token_list_hash from typing import List, Any, Union @@ -77,7 +79,8 @@ class Req(ctypes.Structure): # 虽然某种程度上 cur_output_len 也有同样的功能,但是为了避免多进程访问导致的问题,添加 # candetoken_out_len 变量单独传输这个信息。 ("candetoken_out_len", ctypes.c_int), - ("prompt_cache_len", ctypes.c_int), # 用于记录prompt cache 的命中长度,用于统计 + ("prompt_cache_len", ctypes.c_int), # 用于记录prompt cache 的命中长度,用于统计,这里指gpu kv cache命中长度 + ("cpu_prompt_cache_len", ctypes.c_int), # 用于记录在 enable_cpu_cache 的场景下,命中的 cpu kv cache 的长度 ("is_paused", ctypes.c_bool), # 标记一个Req因为显存资源管理的原因被临时暂停了。 ("finish_status", FinishStatus), # 这个标记变量是http_server 写入,其他进程读取,用于标记该请求是否因为断网被aborted。 @@ -107,6 +110,12 @@ class Req(ctypes.Structure): # 当 stop_str_matched 条件满足的时候,对应的最后一个生成 token 所在的index位置。 # 该变量为 detokenization 进程写入,http_server 读取 ("stop_str_matched_token_index", ctypes.c_int), + # 用于在开启cpu cache 或者 硬盘 cache时,预先计算,分块输入token的hash值。 + ("token_hash_list", TokenHashList), + # 
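The `acquire_sleep1ms`/`release` pair added to `AtomicShmLock` above is the explicit (non-context-manager) variant used later in this patch around CPU cache page lookups; a small sketch of the intended call pattern (the try/finally is added here for safety, the manager itself calls the pair directly).

```python
from lightllm.server.core.objs import AtomicShmLock
from lightllm.utils.envs_utils import get_unique_server_name

# Cross-process lock guarding the shared CPU KV cache metadata (same name the
# CpuKvCacheClient in this patch uses).
lock = AtomicShmLock(lock_name=f"{get_unique_server_name()}_cpu_kv_cache_client_lock")

# The spin-and-sleep variant is meant for short critical sections where the
# contention comes from another process rather than another thread.
lock.acquire_sleep1ms()
try:
    pass  # e.g. cpu_cache_client.query_one_page(token_hash)
finally:
    lock.release()
```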
用于保存查找匹配到的可以被复用的cpu cache 页面信息。 + ("cpu_cache_match_page_indexes", CpuCachePageList), + # 分块hash的块大小 + ("cpu_cache_token_page_size", ctypes.c_int), ] def get_str(self): @@ -139,6 +148,7 @@ def init( self.shm_cur_output_len = 0 self.candetoken_out_len = 0 self.prompt_cache_len = 0 + self.cpu_prompt_cache_len = 0 self.finish_token_index = -1 self.can_released_mark = False self.reward_score = math.nan @@ -164,10 +174,23 @@ def init( self.post_init() + self.cpu_cache_token_page_size = get_env_start_args().cpu_cache_token_page_size + if get_env_start_args().enable_cpu_cache: + self._fill_input_token_hash() + return + def post_init(self): # 子类继承进行一些额外的初始化操作 pass + def _fill_input_token_hash(self): + self.token_hash_list = TokenHashList() + self.token_hash_list.clear() + hash_values = compute_token_list_hash(self.get_prompt_ids(), self.cpu_cache_token_page_size) + self.token_hash_list.fill(hash_values) + self.cpu_cache_match_page_indexes = CpuCachePageList() + return + def create_prompt_ids_shm_array(self): service_uni_name = get_unique_server_name() name = f"{service_uni_name}_shm_prompts_{self.index_in_shm_mem}" diff --git a/lightllm/server/core/objs/shm_array.py b/lightllm/server/core/objs/shm_array.py index b199dbfc0..c5ad512c6 100644 --- a/lightllm/server/core/objs/shm_array.py +++ b/lightllm/server/core/objs/shm_array.py @@ -11,7 +11,7 @@ def __init__(self, name, shape, dtype): self.shm = None self.arr = None self.name = name - self.dtype_byte_num = np.array([1], dtype=dtype).dtype.itemsize + self.dtype_byte_num = np.dtype(dtype=dtype).itemsize self.dest_size = np.prod(shape) * self.dtype_byte_num self.shape = shape self.dtype = dtype diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index ce5ef56fc..5404c5209 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -104,3 +104,24 @@ class StartArgs: nixl_pd_kv_page_num: int = field(default=16) nixl_pd_kv_page_size: int = field(default=1024) pd_node_id: int = field(default=-1) + enable_cpu_cache: bool = field(default=False) + cpu_cache_storage_size: float = field(default=2) + cpu_cache_token_page_size: int = field(default=64) + enable_disk_cache: bool = field(default=False) + disk_cache_storage_size: float = field(default=10) + # zmp ports + router_port: int = field(default=None) + detokenization_port: int = field(default=None) + http_server_port: int = field(default=None) + visual_port: int = field(default=None) + audio_port: int = field(default=None) + cache_port: int = field(default=None) + metric_port: int = field(default=None) + multinode_httpmanager_port: int = field(default=12345) + multi_level_kv_cache_port: int = field(default=None) + # multi_modal + enable_multimodal: bool = field(default=False) + enable_multimodal_audio: bool = field(default=False) + + # kernel setting + enable_fa3: bool = field(default=False) diff --git a/lightllm/server/core/objs/token_chunck_hash_list.py b/lightllm/server/core/objs/token_chunck_hash_list.py new file mode 100644 index 000000000..245ca5b98 --- /dev/null +++ b/lightllm/server/core/objs/token_chunck_hash_list.py @@ -0,0 +1,76 @@ +import os +import ctypes +from typing import List +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + +LIGHTLLM_TOKEN_HASH_LIST_SIZE = int(os.getenv("LIGHTLLM_TOKEN_HASH_LIST_SIZE", 2048)) + + +class TokenHashList(ctypes.Structure): + _pack_ = 4 + _fields_ = [ + ("items", ctypes.c_uint64 * LIGHTLLM_TOKEN_HASH_LIST_SIZE), # 元素静态数组 
+ ("size", ctypes.c_int), # 队列大小 + ] + + def __init__(self): + # 初始化头和尾 + self.size = 0 + return + + def is_empty(self): + return self.size == 0 + + def is_full(self): + return self.size == LIGHTLLM_TOKEN_HASH_LIST_SIZE + + def fill(self, data: List[int]): + if len(data) > LIGHTLLM_TOKEN_HASH_LIST_SIZE: + logger.warning( + f"Queue capcity is smaller than data size ({len(data)} > {LIGHTLLM_TOKEN_HASH_LIST_SIZE}), " + f"remove tail to write" + ) + data = data[0:LIGHTLLM_TOKEN_HASH_LIST_SIZE] + self.items[0 : len(data)] = data + self.size = len(data) + return + + def clear(self): + self.size = 0 + + def get_all(self): + return list(self.items[0 : self.size]) + + +class CpuCachePageList(ctypes.Structure): + _pack_ = 4 + _fields_ = [ + ("items", ctypes.c_int * LIGHTLLM_TOKEN_HASH_LIST_SIZE), # 元素静态数组 + ("size", ctypes.c_int), # 队列大小 + ] + + def __init__(self): + # 初始化头和尾 + self.size = 0 + return + + def is_empty(self): + return self.size == 0 + + def is_full(self): + return self.size == LIGHTLLM_TOKEN_HASH_LIST_SIZE + + def fill(self, data: List[int]): + assert self.size == 0 + assert len(data) <= LIGHTLLM_TOKEN_HASH_LIST_SIZE + self.items[0 : len(data)] = data + self.size = len(data) + return + + def clear(self): + self.size = 0 + + def get_all(self): + return list(self.items[0 : self.size]) diff --git a/lightllm/server/detokenization/manager.py b/lightllm/server/detokenization/manager.py index 274bdb040..389171ba8 100644 --- a/lightllm/server/detokenization/manager.py +++ b/lightllm/server/detokenization/manager.py @@ -5,7 +5,7 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) import zmq import inspect -from lightllm.server.core.objs import ShmReqManager +from lightllm.server.core.objs import ShmReqManager, StartArgs from lightllm.server.core.objs.io_objs import GroupReqIndexes from lightllm.utils.graceful_utils import graceful_registry from typing import Union, Dict, List @@ -24,26 +24,20 @@ class DeTokenizationManager: def __init__( self, - args, - eos_id, - model_weightdir, - tokenizor_mode, - detokenization_port, - detokenization_pub_port, - trust_remote_code, + args: StartArgs, ): self.args = args context = zmq.Context(2) - self.recv_from_router = context.socket(zmq.PULL) - self.recv_from_router.bind(f"{args.zmq_mode}127.0.0.1:{detokenization_port}") + self.zmq_recv_socket = context.socket(zmq.PULL) + self.zmq_recv_socket.bind(f"{args.zmq_mode}127.0.0.1:{args.detokenization_port}") self.pub_to_httpserver = context.socket(zmq.PUB) - self.pub_to_httpserver.bind(f"{args.zmq_mode}127.0.0.1:{detokenization_pub_port}") + self.pub_to_httpserver.bind(f"{args.zmq_mode}127.0.0.1:{args.http_server_port}") logger.info(f"pub_to_httpserver sendhwm {self.pub_to_httpserver.getsockopt(zmq.SNDHWM)}") - self.tokenizer = get_tokenizer(model_weightdir, tokenizor_mode, trust_remote_code=trust_remote_code) + self.tokenizer = get_tokenizer(args.model_dir, args.tokenizer_mode, trust_remote_code=args.trust_remote_code) self.all_special_ids = set(self.tokenizer.all_special_ids) self.req_id_to_out: Dict[int, DecodeReq] = {} - self.eos_id = eos_id + self.eos_id = args.eos_id self._init_get_token_id_to_token_str() self.is_pd_decode_mode = self.args.run_mode == "decode" self.shm_req_manager = ShmReqManager() @@ -80,7 +74,7 @@ def handle_loop(self): try: # 一次最多从 zmq 中取 recv_max_count 个请求,防止 zmq 队列中请求数量过多导致阻塞了主循环。 for _ in range(recv_max_count): - recv_obj: GroupReqIndexes = self.recv_from_router.recv_pyobj(zmq.NOBLOCK) + recv_obj: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) assert 
isinstance(recv_obj, GroupReqIndexes) self._add_new_group_req_index(recv_obj=recv_obj) @@ -172,20 +166,14 @@ def remove_finished_reqs(self): return -def start_detokenization_process(args, detokenization_port, detokenization_pub_port, pipe_writer): +def start_detokenization_process(args, pipe_writer): # 注册graceful 退出的处理 graceful_registry(inspect.currentframe().f_code.co_name) setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::detokenization_server") try: manager = DeTokenizationManager( - args, - args.eos_id, - args.model_dir, - args.tokenizer_mode, - detokenization_port=detokenization_port, - detokenization_pub_port=detokenization_pub_port, - trust_remote_code=args.trust_remote_code, + args=args, ) except Exception as e: pipe_writer.send(str(e)) diff --git a/lightllm/server/embed_cache/manager.py b/lightllm/server/embed_cache/manager.py index f0b68c45e..34eec3b4f 100644 --- a/lightllm/server/embed_cache/manager.py +++ b/lightllm/server/embed_cache/manager.py @@ -3,6 +3,7 @@ import inspect import setproctitle from typing import Union, Optional +from lightllm.server.core.objs import StartArgs from lightllm.utils.graceful_utils import graceful_registry from lightllm.server.embed_cache.impl.naive_memory_cache import InMemoryCache from rpyc.utils.classic import obtain @@ -51,7 +52,7 @@ def exposed_get_items_embed(self, ids: list[int]) -> list[bool]: return self._impl.get_items_embed(ids) -def start_cache_manager(port: int, args, pipe_writer): +def start_cache_manager(args: StartArgs, pipe_writer): # 注册graceful 退出的处理 graceful_registry(inspect.currentframe().f_code.co_name) @@ -60,7 +61,7 @@ def start_cache_manager(port: int, args, pipe_writer): service = CacheServer(manager) from rpyc.utils.server import ThreadedServer - t = ThreadedServer(service, port=port, protocol_config={"allow_pickle": True}) + t = ThreadedServer(service, port=args.cache_port, protocol_config={"allow_pickle": True}) pipe_writer.send("init ok") t.start() diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index dc04d081f..8c07e6e8c 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -43,17 +43,11 @@ class HttpServerManager: def __init__( self, args: StartArgs, - router_port, - cache_port, - detokenization_pub_port, - visual_port, - metric_port, - enable_multimodal, ): self.args: StartArgs = args context = zmq.asyncio.Context(2) self.send_to_router = context.socket(zmq.PUSH) - self.send_to_router.connect(f"{args.zmq_mode}127.0.0.1:{router_port}") + self.send_to_router.connect(f"{args.zmq_mode}127.0.0.1:{args.router_port}") self.multinode_req_manager = None self.nnodes = args.nnodes @@ -82,17 +76,21 @@ def __init__( f"HttpServerManager listening for child node requests on *:{args.multinode_httpmanager_port}" ) - self.enable_multimodal = enable_multimodal + self.enable_multimodal = args.enable_multimodal if self.enable_multimodal: - self.cache_client = rpyc.connect("localhost", cache_port, config={"allow_pickle": True}) + self.cache_client = rpyc.connect("localhost", args.cache_port, config={"allow_pickle": True}) self.send_to_visual = context.socket(zmq.PUSH) - self.send_to_visual.connect(f"{args.zmq_mode}127.0.0.1:{visual_port}") + self.send_to_visual.connect(f"{args.zmq_mode}127.0.0.1:{args.visual_port}") + if args.enable_cpu_cache and not self.args.enable_multimodal: + self.send_to_multi_level_kv_cache = context.socket(zmq.PUSH) + 
self.send_to_multi_level_kv_cache.connect(f"{args.zmq_mode}127.0.0.1:{args.multi_level_kv_cache_port}") self.shm_req_manager = ShmReqManager() - self.recv_from_detokenization = context.socket(zmq.SUB) - self.recv_from_detokenization.connect(f"{args.zmq_mode}127.0.0.1:{detokenization_pub_port}") - self.recv_from_detokenization.setsockopt(zmq.SUBSCRIBE, b"") + # recv from detokenization + self.zmq_recv_socket = context.socket(zmq.SUB) + self.zmq_recv_socket.connect(f"{args.zmq_mode}127.0.0.1:{args.http_server_port}") + self.zmq_recv_socket.setsockopt(zmq.SUBSCRIBE, b"") self.tokenizer = get_tokenizer(args.model_dir, args.tokenizer_mode, trust_remote_code=args.trust_remote_code) @@ -100,7 +98,7 @@ def __init__( self.forwarding_queue: AsyncQueue = None # p d 分离模式使用的转发队列, 需要延迟初始化 self.max_req_total_len = args.max_req_total_len - self.metric_client = MetricClient(metric_port) + self.metric_client = MetricClient(args.metric_port) self.pd_mode: NodeRole = NodeRole(self.args.run_mode) assert self.pd_mode in [NodeRole.P, NodeRole.D, NodeRole.NORMAL, NodeRole.NP, NodeRole.ND] @@ -496,27 +494,33 @@ async def transfer_to_next_module( group_req_objs: Optional[GroupReqObjs] = None, ): - if self.pd_mode.is_P() or self.pd_mode.is_normal(): + if self.pd_mode.is_P_or_NORMAL(): if self.enable_multimodal: self.send_to_visual.send_pyobj( group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL, ) - else: - self.send_to_router.send_pyobj( + return + + if self.args.enable_cpu_cache: + self.send_to_multi_level_kv_cache.send_pyobj( group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL, ) - return + return - if self.pd_mode.is_D(): - # 在 D 模式下,不需要传输真的多模态参数,因为其已经被 P 处理好了, 传输一个空的即可 self.send_to_router.send_pyobj( group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL, ) return + if self.pd_mode.is_D(): + # 在 D 模式下,不需要传输真的多模态参数,因为其已经被 P 处理好了 + self.send_to_router.send_pyobj( + group_req_objs.to_group_req_index(), + protocol=pickle.HIGHEST_PROTOCOL, + ) return assert False, "dead code path" @@ -562,6 +566,7 @@ async def _wait_to_token_package( metadata["prompt_ids"] = prompt_ids prompt_cache_len = metadata.pop("prompt_cache_len", 0) + cpu_prompt_cache_len = metadata.pop("cpu_prompt_cache_len", 0) if is_first_token: first_token_cost_ms = (time.time() - start_time) * 1000 is_first_token = False @@ -598,6 +603,8 @@ async def _wait_to_token_package( f"prompt_token_num:{prompt_tokens} " f"prompt_cache_len:{prompt_cache_len} " f"prompt_cache_ratio:{prompt_cache_ratio} " + f"cpu_prompt_cache_len:{cpu_prompt_cache_len} " + f"used_cpu_prompt_cache_len:{max(0, cpu_prompt_cache_len - prompt_cache_len)} " f"mtp_avg_token_per_step:{mtp_avg_token_per_step} " ) if group_request_id < 0: @@ -689,7 +696,7 @@ async def handle_loop(self): while True: try: - await asyncio.wait_for(self.recv_from_detokenization.recv_pyobj(), timeout=0.05) + await asyncio.wait_for(self.zmq_recv_socket.recv_pyobj(), timeout=0.05) except asyncio.TimeoutError: pass @@ -718,6 +725,7 @@ async def handle_loop(self): "special": special, "count_output_tokens": count_output_tokens, "prompt_cache_len": req.prompt_cache_len, + "cpu_prompt_cache_len": req.cpu_prompt_cache_len, "mtp_accepted_token_num": req.mtp_accepted_token_num, } if self.args.return_all_prompt_logprobs: diff --git a/lightllm/server/httpserver_for_pd_master/manager.py b/lightllm/server/httpserver_for_pd_master/manager.py index 4e7349a05..4929a3a52 100644 --- a/lightllm/server/httpserver_for_pd_master/manager.py +++ 
b/lightllm/server/httpserver_for_pd_master/manager.py @@ -29,10 +29,9 @@ class HttpServerManagerForPDMaster: def __init__( self, args: StartArgs, - metric_port: int, ): self.args = args - self.metric_client = MetricClient(metric_port) + self.metric_client = MetricClient(args.metric_port) self.id_gen = ReqIDGenerator() self.pd_manager = PDManager(args) diff --git a/lightllm/server/metrics/manager.py b/lightllm/server/metrics/manager.py index a9de4090b..f3b1a5275 100644 --- a/lightllm/server/metrics/manager.py +++ b/lightllm/server/metrics/manager.py @@ -8,6 +8,7 @@ import setproctitle from .metrics import Monitor from prometheus_client import generate_latest +from lightllm.server.core.objs import StartArgs from rpyc import SocketStream from lightllm.utils.log_utils import init_logger from lightllm.utils.graceful_utils import graceful_registry @@ -135,7 +136,7 @@ def run(self): logger.error(f"monitor error {str(e)}") -def start_metric_manager(port: int, args, pipe_writer): +def start_metric_manager(args: StartArgs, pipe_writer): # 注册graceful 退出的处理 graceful_registry(inspect.currentframe().f_code.co_name) setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::metric_manager") @@ -147,6 +148,6 @@ def start_metric_manager(port: int, args, pipe_writer): from rpyc.utils.server import ThreadedServer - t = ThreadedServer(service, port=port) + t = ThreadedServer(service, port=args.metric_port) pipe_writer.send("init ok") t.start() diff --git a/lightllm/server/multi_level_kv_cache/__init__.py b/lightllm/server/multi_level_kv_cache/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/server/multi_level_kv_cache/cpu_cache_client.py b/lightllm/server/multi_level_kv_cache/cpu_cache_client.py new file mode 100644 index 000000000..ba53179c5 --- /dev/null +++ b/lightllm/server/multi_level_kv_cache/cpu_cache_client.py @@ -0,0 +1,284 @@ +import ctypes +import torch +import numpy as np +from lightllm.utils.envs_utils import get_env_start_args, get_unique_server_name +from typing import List, Optional, Tuple +from lightllm.utils.log_utils import init_logger +from .shm_objs import ShmDict, ShmLinkedList, _LinkedListItem, IntList +from lightllm.server.core.objs import AtomicShmLock +from lightllm.utils.kv_cache_utils import ( + calcu_cpu_cache_meta, + create_shm_kv_cache_ptr, + attach_shm_kv_cache_ptr, + register_shm_ptr_to_pin, +) + +logger = init_logger(__name__) + + +class CpuKvCacheClient(object): + """ + This class is responsible for handling cpu kv cache meta data. + """ + + def __init__(self, only_create_meta_data: bool, init_shm_data: bool): + self.args = get_env_start_args() + # to do here need calcu from from settings. 
+ self.kv_cache_tensor_meta = calcu_cpu_cache_meta() + self.page_num: int = self.kv_cache_tensor_meta.page_num + self.lock = AtomicShmLock(lock_name=f"{get_unique_server_name()}_cpu_kv_cache_client_lock") + self._create_cpu_status_list(init_shm_data) + + if not only_create_meta_data: + if init_shm_data: + self._create_shm_cpu_kv_cache() + self.attach_shm_handle = None + else: + self.attach_shm_handle = self._attach_shm_cpu_kv_cache() + return + + def get_one_empty_page(self, hash_key: int, disk_offload_enable: bool) -> Optional[int]: + assert self.page_hash_dict.get(hash_key) is None + head = self.page_items.head + tail = self.page_items.tail + cur_page: _CpuPageStatus = head.get_next_item() + if cur_page.self_index == tail.self_index: + return None + + if cur_page.can_realloc(disk_offload_enable=disk_offload_enable): + page_index = cur_page.self_index + cur_page.del_self_from_list() + if not cur_page.is_empty(): + self.page_hash_dict.remove(cur_page.hash_key) + cur_page.hash_key = hash_key + cur_page.status = cur_page.LOADING + cur_page.ref_count += 1 + self.page_hash_dict.put(hash_key, page_index) + self.page_items.add_item_to_tail(cur_page.self_index) + return page_index + else: + return None + + def allocate_one_page(self, hash_key: int, disk_offload_enable: bool) -> Tuple[Optional[int], bool]: + page_index = self.page_hash_dict.get(hash_key) + if page_index is not None: + page_item: _CpuPageStatus = self.page_items.get_item_by_index(page_index) + if page_item.is_data_ready(): + page_item.ref_count += 1 + page_item.del_self_from_list() + self.page_items.add_item_to_tail(index=page_index) + return page_index, True + else: + page_item.ref_count += 1 + page_item.del_self_from_list() + self.page_items.add_item_to_tail(index=page_index) + return page_index, False + else: + page_index = self.get_one_empty_page(hash_key=hash_key, disk_offload_enable=disk_offload_enable) + if page_index is not None: + return page_index, False + else: + return None, False + + def allocate_pages(self, hash_keys: List[int], disk_offload_enable: bool) -> Tuple[List[int], List[bool]]: + """ + allocate_pages will add _CpuPageStaus ref_count + """ + page_list = [] + ready_list = [] + for hash_key in hash_keys: + page_index, ready = self.allocate_one_page(hash_key=hash_key, disk_offload_enable=disk_offload_enable) + if page_index is not None: + page_list.append(page_index) + ready_list.append(ready) + else: + page_list.append(-1) + ready_list.append(False) + break + + left_num = len(hash_keys) - len(page_list) + page_list.extend([-1 for _ in range(left_num)]) + ready_list.extend([False for _ in range(left_num)]) + return page_list, ready_list + + def update_pages_status_to_ready(self, page_list: List[int], deref: bool = True, disk_offload_enable: bool = False): + for page_index in page_list: + if page_index != -1: + cur_page: _CpuPageStatus = self.page_items.get_item_by_index(page_index) + if cur_page.status < cur_page.READY: + cur_page.status = cur_page.READY + if disk_offload_enable: + self.offload_page_indexes.add_item(value=cur_page.self_index) + if deref: + assert cur_page.ref_count > 0 + cur_page.ref_count -= 1 + return + + def query_one_page(self, hash_key: int) -> Tuple[Optional[int], bool]: + page_index = self.page_hash_dict.get(hash_key) + if page_index is not None: + page_item: _CpuPageStatus = self.page_items.get_item_by_index(page_index) + if page_item.is_data_ready(): + page_item.ref_count += 1 + # lru 更新 + page_item.del_self_from_list() + self.page_items.add_item_to_tail(index=page_index) + return 
page_index, True + else: + # lru 更新 + page_item.del_self_from_list() + self.page_items.add_item_to_tail(index=page_index) + return None, False + else: + return None, False + + def check_allpages_ready(self, page_list: List[int]) -> bool: + for page_index in page_list: + if page_index == -1: + continue + page_item: _CpuPageStatus = self.page_items.get_item_by_index(page_index) + if not page_item.is_data_ready(): + return False + return True + + def deref_pages(self, page_list: List[int]): + """ + deref_pages + """ + for page_index in page_list: + if page_index != -1: + self.deref_one_page(page_index=page_index) + return + + def deref_one_page(self, page_index: int): + page_item: _CpuPageStatus = self.page_items.get_item_by_index(page_index) + assert page_item.ref_count > 0 + page_item.ref_count -= 1 + return + + def get_pages_to_offloading(self) -> List[int]: + page_list = self.offload_page_indexes.pop_all_item() + ans_list = [] + if page_list is not None: + for page_index in page_list: + page_item: _CpuPageStatus = self.page_items.get_item_by_index(index=page_index) + if page_item.is_ready(): + page_item.ref_count += 1 + page_item.status = page_item.OFFLOADING + ans_list.append(page_index) + return ans_list + + def update_pages_status_to_ready_recycle(self, page_list: List[int], deref: bool = True): + for page_index in page_list: + if page_index != -1: + cur_page: _CpuPageStatus = self.page_items.get_item_by_index(page_index) + assert cur_page.is_offloading() + cur_page.status = cur_page.READY_RECYCLE + if deref: + assert cur_page.ref_count > 0 + cur_page.ref_count -= 1 + return + + def _create_cpu_status_list(self, init_shm_data: bool): + self.page_items = ShmLinkedList( + name=f"{get_unique_server_name()}_cpu_kv_cache_page_items", + item_class=_CpuPageStatus, + capacity=self.page_num, + init_shm_data=init_shm_data, + ) + self.page_hash_dict = ShmDict( + name=f"{get_unique_server_name()}_cpu_kv_cache_hash", + capacity=self.page_num * 2, + init_shm_data=init_shm_data, + ) + self.offload_page_indexes = IntList( + name=f"{get_unique_server_name()}_cpu_kv_cache_offload_page_indexes", + capacity=self.page_num, + init_shm_data=init_shm_data, + ) + return + + def _create_shm_cpu_kv_cache(self): + shm_ptr = create_shm_kv_cache_ptr() + numpy_array = np.frombuffer( + memoryview((ctypes.c_uint8 * self.kv_cache_tensor_meta.calcu_size()).from_address(shm_ptr)), dtype=np.uint8 + ) + # 将 NumPy 数组转换为 PyTorch 张量 + shape = ( + self.kv_cache_tensor_meta.page_num, + self.kv_cache_tensor_meta.layer_num, + self.kv_cache_tensor_meta.token_page_size, + self.kv_cache_tensor_meta.num_heads, + self.kv_cache_tensor_meta.head_dim, + ) + self.cpu_kv_cache_tensor = torch.from_numpy(numpy_array).view(dtype=torch.bfloat16).view(shape) + return + + def _attach_shm_cpu_kv_cache(self): + shm_ptr = attach_shm_kv_cache_ptr() + handle = register_shm_ptr_to_pin(shm_ptr=shm_ptr, size=self.kv_cache_tensor_meta.calcu_size()) + numpy_array = np.frombuffer( + memoryview((ctypes.c_uint8 * self.kv_cache_tensor_meta.calcu_size()).from_address(shm_ptr)), dtype=np.uint8 + ) + shape = ( + self.kv_cache_tensor_meta.page_num, + self.kv_cache_tensor_meta.layer_num, + self.kv_cache_tensor_meta.token_page_size, + self.kv_cache_tensor_meta.num_heads, + self.kv_cache_tensor_meta.head_dim, + ) + self.cpu_kv_cache_tensor = torch.from_numpy(numpy_array).view(dtype=torch.bfloat16).view(shape) + assert shm_ptr == self.cpu_kv_cache_tensor.data_ptr() + + # test code + # self.cpu_kv_cache_tensor = torch.zeros_like(self.cpu_kv_cache_tensor, device="cpu", 
pin_memory=True) + # self.cpu_kv_cache_tensor = torch.zeros_like(self.cpu_kv_cache_tensor, device="cuda") + return handle + + +class _CpuPageStatus(_LinkedListItem): + _pack_ = 4 + _fields_ = [("status", ctypes.c_int), ("ref_count", ctypes.c_int), ("hash_key", ctypes.c_uint64)] + + EMPTY = 0 # 空闲 + LOADING = 1 # 从 gpu buffer 加载到 cpu 的状态,或者是从磁盘加载到 cpu 的状态 + READY = 2 # 数据已经加载到 cpu ok 的状态 + OFFLOADING = 3 # 从 cpu 卸载到 硬盘的状态 + READY_RECYCLE = 4 # 因为卸载到硬盘已经完成,所以可以进行回收使用 + + def __init__(self): + self.init() + + def init(self): + super().init() + self.ref_count = 0 + self.status = self.EMPTY + self.hash_key = 0 + return + + def is_empty(self): + return self.status == self.EMPTY + + def is_loading(self): + return self.status == self.LOADING + + def is_ready(self): + return self.status == self.READY + + def is_offloading(self): + return self.status == self.OFFLOADING + + def is_ready_recycle(self): + return self.status == self.READY_RECYCLE + + def is_data_ready(self): + """ + 判断数据是否是填充ok的,可能包含多种状态下属于数据是可填充的状态。 + """ + return self.status >= self.READY + + def can_realloc(self, disk_offload_enable: bool): + if disk_offload_enable: + return (self.is_empty() or self.is_ready_recycle()) and self.ref_count == 0 + else: + return (self.is_empty() or self.is_data_ready()) and self.ref_count == 0 diff --git a/lightllm/server/multi_level_kv_cache/manager.py b/lightllm/server/multi_level_kv_cache/manager.py new file mode 100644 index 000000000..8853e352e --- /dev/null +++ b/lightllm/server/multi_level_kv_cache/manager.py @@ -0,0 +1,154 @@ +import uvloop +import asyncio + +asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) +import zmq +import inspect +import pickle +import time +import threading +import concurrent.futures +from queue import Queue +from lightllm.server.core.objs import ShmReqManager, Req, StartArgs +from lightllm.server.core.objs.io_objs import GroupReqIndexes +from lightllm.utils.graceful_utils import graceful_registry +from .cpu_cache_client import CpuKvCacheClient +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class MultiLevelKVCacheManager: + def __init__( + self, + args: StartArgs, + ): + self.args: StartArgs = args + context = zmq.Context(2) + self.zmq_recv_socket = context.socket(zmq.PULL) + self.zmq_recv_socket.bind(f"{args.zmq_mode}127.0.0.1:{args.multi_level_kv_cache_port}") + + self.send_to_router = context.socket(zmq.PUSH) + self.send_to_router.connect(f"{args.zmq_mode}127.0.0.1:{args.router_port}") + logger.info(f"send_to_router sendhwm {self.send_to_router.getsockopt(zmq.SNDHWM)}") + self.cpu_cache_client = CpuKvCacheClient(only_create_meta_data=False, init_shm_data=True) + self.shm_req_manager = ShmReqManager() + self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=6) + # 控制进行 cpu cache 页面匹配的时间,超过时间则不再匹配,直接转发。 + self.cpu_cache_time_out = 0.5 + self.recv_queue = Queue(maxsize=1024) + self.cpu_cache_thread = threading.Thread(target=self.cpu_cache_hanle_loop, daemon=True) + self.cpu_cache_thread.start() + return + + def cpu_cache_hanle_loop(self): + while True: + try: + current_group_req = self.recv_queue.get() + + self.executor.submit(self._handle_group_req_cpu_cache_match, current_group_req, time.time()) + except BaseException as e: + logger.exception(str(e)) + return + + def _handle_group_req_cpu_cache_match(self, group_req_indexes: GroupReqIndexes, start_time: float): + """ + match cpu cache pages + """ + # 超时时,放弃进行 cpu cache page 的匹配。 + current_time = time.time() + if current_time - start_time >= 
self.cpu_cache_time_out: + self.send_to_router.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) + logger.warning( + f"cpu cache match time out {current_time - start_time}s, group_req_id: {group_req_indexes.group_req_id}" + ) + return + + reqs_shm_index = group_req_indexes.shm_req_indexes + reqs = [self.shm_req_manager.get_req_obj_by_index(index) for index in reqs_shm_index] + + # 对每个请求进行cpu cache page 的匹配操作。 + for req in reqs: + # diverse_mode 只有主请求一个初始化 cpu cache 信息。 + if self.args.diverse_mode and req.request_id != req.group_req_id: + continue + + if req.is_aborted: + continue + + req: Req = req + finded_page_indexes = [] + for token_chuncked_hash_value in req.token_hash_list.get_all(): + self.cpu_cache_client.lock.acquire_sleep1ms() + page_index, ready = self.cpu_cache_client.query_one_page(token_chuncked_hash_value) + self.cpu_cache_client.lock.release() + + if page_index is not None: + assert ready + finded_page_indexes.append(page_index) + else: + break + + # 等待所有的 cpu cache 页面ready + while not self.cpu_cache_client.check_allpages_ready(finded_page_indexes): + time.sleep(0.01) + + req.cpu_cache_match_page_indexes.fill(finded_page_indexes) + + for req in reqs: + self.shm_req_manager.put_back_req_obj(req) + + self.send_to_router.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) + return + + def recv_loop(self): + try: + recv_max_count = 128 + + while True: + recv_objs = [] + try: + # 一次最多从 zmq 中取 recv_max_count 个请求,防止 zmq 队列中请求数量过多导致阻塞了主循环。 + for _ in range(recv_max_count): + recv_obj: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) + assert isinstance(recv_obj, GroupReqIndexes) + recv_objs.append(recv_obj) + + start_time = recv_obj.time_mark + logger.info( + f"multi_level_kv_cache recive group req id {recv_obj.group_req_id} " + f"cost time {time.time() - start_time} s" + ) + + # 当队列中存在较多的请求时,将一次接受的数量上调 + recv_max_count = min(int(recv_max_count * 1.3), 256) + except zmq.ZMQError: + # 当队列已经开始清空的时候,将一次接受的数量下调 + recv_max_count = 128 + + for recv_obj in recv_objs: + self.recv_queue.put(recv_obj) + + if len(recv_objs) == 0: + time.sleep(0.01) + + except Exception as e: + logger.exception(f"detoken process has exception {str(e)}") + return + + +def start_multi_level_kv_cache_manager(args, pipe_writer): + # 注册graceful 退出的处理 + graceful_registry(inspect.currentframe().f_code.co_name) + + try: + manager = MultiLevelKVCacheManager( + args=args, + ) + except Exception as e: + pipe_writer.send(str(e)) + raise + + pipe_writer.send("init ok") + manager.recv_loop() + return diff --git a/lightllm/server/multi_level_kv_cache/shm_objs.py b/lightllm/server/multi_level_kv_cache/shm_objs.py new file mode 100644 index 000000000..a28ec9fba --- /dev/null +++ b/lightllm/server/multi_level_kv_cache/shm_objs.py @@ -0,0 +1,273 @@ +import ctypes +import numpy as np +from multiprocessing import shared_memory +from typing import List, Optional +from lightllm.utils.log_utils import init_logger +from lightllm.utils.auto_shm_cleanup import register_posix_shm_for_cleanup + +logger = init_logger(__name__) + + +class IntList(object): + def __init__(self, name: str, capacity: int, init_shm_data: bool): + self.capacity: int = capacity + byte_size = np.dtype(np.int32).itemsize * (self.capacity + 1) + shm_name = name + shm = _create_shm(name=shm_name, byte_size=byte_size) + self.shm = shm + + if self.shm.size != byte_size: + logger.info(f"size not same, unlink lock shm {self.shm.name} and create again") + self.shm.close() + self.shm.unlink() + self.shm = None + self.shm = 
_create_shm(name=shm_name, byte_size=byte_size) + + self.arr = np.ndarray((self.capacity + 1), dtype=np.int32, buffer=self.shm.buf) + if init_shm_data: + self.arr.fill(0) + return + + def size(self): + return self.arr[-1] + + def add_item(self, value: int): + write_index = self.arr[-1] + self.arr[write_index] = value + self.arr[-1] += 1 + return + + def pop_all_item(self) -> Optional[List[int]]: + if self.size() == 0: + return None + + ans = self.arr[0 : self.size()].tolist() + self.arr[-1] = 0 + return ans + + +class ShmLinkedList(object): + def __init__(self, name: str, item_class: "_LinkedListItem.__class__", capacity: int, init_shm_data: bool): + self.capacity: int = capacity + # add head and tail node. + byte_size = ctypes.sizeof(item_class) * (self.capacity + 2) + shm_name = name + shm = _create_shm(name=shm_name, byte_size=byte_size) + self.shm = shm + + if self.shm.size != byte_size: + logger.info(f"size not same, unlink lock shm {self.shm.name} and create again") + self.shm.close() + self.shm.unlink() + self.shm = None + self.shm = _create_shm(name=shm_name, byte_size=byte_size) + # 构建 hash table 表 + self.linked_items: List[_LinkedListItem] = (item_class * (self.capacity + 2)).from_buffer(self.shm.buf) + # 如果不转变存储,set_list_obj 的对象上绑定的非shm信息在下一次从 shm 中获取对象时将丢失 + self.linked_items = [item for item in self.linked_items] + for e in self.linked_items: + e.set_list_obj(self) + + self.head = self.linked_items[self.capacity] + self.tail = self.linked_items[self.capacity + 1] + + if init_shm_data: + for e in self.linked_items: + e.init() + + self.head.self_index = self.capacity + self.tail.self_index = self.capacity + 1 + self.head.next_index = self.tail.self_index + self.tail.pre_index = self.head.self_index + + for i in range(self.capacity): + item = self.linked_items[i] + item.self_index = i + self.add_item_to_tail(i) + return + + def add_item_to_tail(self, index: int): + item = self.linked_items[index] + pre_node = self.linked_items[self.tail.pre_index] + pre_node.next_index = item.self_index + item.pre_index = pre_node.self_index + item.next_index = self.tail.self_index + self.tail.pre_index = item.self_index + return + + def get_item_by_index(self, index: int) -> "_LinkedListItem": + item = self.linked_items[index] + return item + + def pop_head_item(self) -> "_LinkedListItem": + head_item = self.linked_items[self.head.next_index] + if head_item.self_index == self.tail.self_index: + return None + head_item.del_self_from_list() + return head_item + + +class ShmDict(object): + def __init__(self, name: str, capacity: int, init_shm_data: bool): + self.capacity: int = capacity + self.link_items: ShmLinkedList = ShmLinkedList( + name=name, item_class=_HashLinkItem, capacity=self.capacity * 2, init_shm_data=init_shm_data + ) + # 将前capacity个item,作为hash item的链表头。 + if init_shm_data: + for i in range(self.capacity): + self.link_items.pop_head_item() + item: _HashLinkItem = self.link_items.get_item_by_index(i) + item.pre_index = -1 + item.next_index = -1 + return + + def put(self, key: int, value: int): + dest_index = key % self.capacity + hash_item: _HashLinkItem = self.link_items.get_item_by_index(dest_index) + if hash_item.next_index == -1: # 空的 + add_link_item: _HashLinkItem = self.link_items.pop_head_item() + add_link_item.key = key + add_link_item.value = value + hash_item.next_index = add_link_item.self_index + add_link_item.pre_index = hash_item.self_index + add_link_item.next_index = -1 + return + + # 存在元素,先遍历是否已经存在 + start_link_item: _HashLinkItem = hash_item.get_next_item() + 
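+ # walk this bucket's collision chain: if the key already exists its value is
+ # overwritten in place, otherwise a fresh node is taken from the shared free
+ # list and linked at the tail of the chain.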
cur_link_item = start_link_item + # 找到对应key的元素,并设置对应的value + while True: + if cur_link_item.key == key: + cur_link_item.value = value + return + else: + next_item = cur_link_item.get_next_item() + if next_item is None: + break + else: + cur_link_item = next_item + + # 没有找到时候,直接插入一个新的节点 + add_link_item: _HashLinkItem = self.link_items.pop_head_item() + add_link_item.key = key + add_link_item.value = value + + cur_link_item.next_index = add_link_item.self_index + add_link_item.pre_index = cur_link_item.self_index + add_link_item.next_index = -1 + return + + def get(self, key: int) -> Optional[int]: + dest_index = key % self.capacity + hash_item: _HashLinkItem = self.link_items.get_item_by_index(dest_index) + if hash_item.next_index == -1: + return None + else: + start_link_item: _HashLinkItem = hash_item.get_next_item() + cur_link_item = start_link_item + # 找到对应key的元素,并设置对应的value + while cur_link_item is not None: + if cur_link_item.key == key: + return cur_link_item.value + else: + cur_link_item = cur_link_item.get_next_item() + return None + + def remove(self, key: int): + dest_index = key % self.capacity + hash_item: _HashLinkItem = self.link_items.get_item_by_index(dest_index) + if hash_item.next_index == -1: + logger.warning(f"shm dict not contain key {key}") + return + + start_link_item: _HashLinkItem = hash_item.get_next_item() + cur_link_item = start_link_item + + # 找到对应key的元素,并设置对应的value + while cur_link_item is not None: + if cur_link_item.key == key: + break + else: + cur_link_item = cur_link_item.get_next_item() + + if cur_link_item is not None: + # remove item + pre_item = cur_link_item.get_pre_item() + pre_item.next_index = cur_link_item.next_index + if cur_link_item.next_index != -1: + next_item = cur_link_item.get_next_item() + next_item.pre_index = pre_item.self_index + + self.link_items.add_item_to_tail(index=cur_link_item.self_index) + else: + logger.warning(f"shm dict not contain key {key}") + return + + +class _LinkedListItem(ctypes.Structure): + _pack_ = 4 + _fields_ = [ + ("self_index", ctypes.c_int), + ("pre_index", ctypes.c_int), + ("next_index", ctypes.c_int), + ] + + def __init__(self): + self.init() + + def init(self): + self.self_index = -1 + self.pre_index = -1 + self.next_index = -1 + return + + def set_list_obj(self, parent_list: ShmLinkedList): + self.linked_items = parent_list.linked_items + return + + def get_next_item(self) -> "_LinkedListItem": + if self.next_index == -1: + return None + return self.linked_items[self.next_index] + + def get_pre_item(self) -> "_LinkedListItem": + if self.pre_index == -1: + return None + return self.linked_items[self.pre_index] + + def del_self_from_list(self): + pre_node = self.get_pre_item() + next_node = self.get_next_item() + pre_node.next_index = next_node.self_index + next_node.pre_index = pre_node.self_index + return + + +class _HashLinkItem(_LinkedListItem): + _pack_ = 4 + _fields_ = [ + ("key", ctypes.c_uint64), + ("value", ctypes.c_int), + ] + + def __init__(self): + self.init() + + def init(self): + super().init() + self.key = 0 + self.value = -1 + + +def _create_shm(name: str, byte_size: int, auto_cleanup: bool = False): + try: + shm = shared_memory.SharedMemory(name=name, create=True, size=byte_size) + if auto_cleanup: + register_posix_shm_for_cleanup(name) + logger.info(f"create lock shm {name}") + except: + shm = shared_memory.SharedMemory(name=name, create=False, size=byte_size) + logger.info(f"link lock shm {name}") + return shm diff --git a/lightllm/server/router/dynamic_prompt/radix_cache.py 
b/lightllm/server/router/dynamic_prompt/radix_cache.py index 65ec4354b..28c4ceb1e 100644 --- a/lightllm/server/router/dynamic_prompt/radix_cache.py +++ b/lightllm/server/router/dynamic_prompt/radix_cache.py @@ -1,10 +1,10 @@ # Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/managers/router/radix_cache.py import torch import numpy as np -from typing import Tuple, Dict, Set, List +import collections +from typing import Tuple, Dict, Set, List, Optional, Union from sortedcontainers import SortedSet from .shared_arr import SharedArray -from lightllm.common.mem_manager import MemoryManager class UniqueTimeIdGenerator: @@ -103,8 +103,10 @@ class RadixCache: unique_name 主要用于解决单机,多实列部署时的shm冲突 """ - def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager: MemoryManager = None): - self.mem_manager = mem_manager + def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager=None): + from lightllm.common.mem_manager import MemoryManager + + self.mem_manager: MemoryManager = mem_manager self._key_dtype = torch.int64 self._value_dtype = torch.int64 @@ -123,67 +125,110 @@ def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager: Memo ) self.tree_total_tokens_num.arr[0] = 0 - def insert(self, key, value=None): + def insert(self, key, value=None) -> Tuple[int, Optional[TreeNode]]: if value is None: value = key assert len(key) == len(value) # and len(key) >= 1 if len(key) == 0: - return 0 + return 0, None return self._insert_helper(self.root_node, key, value) - def _insert_helper(self, node: TreeNode, key, value): + def _insert_helper(self, node: TreeNode, key, value) -> Tuple[int, Optional[TreeNode]]: + handle_stack = collections.deque() + update_list = collections.deque() + handle_stack.append((node, key, value)) + + ans_prefix_len = 0 + ans_node = None + + while len(handle_stack) != 0: + node, key, value = handle_stack.popleft() + ans_tuple = self._insert_helper_no_recursion(node=node, key=key, value=value) + if len(ans_tuple) == 4: + (_prefix_len, new_node, new_key, new_value) = ans_tuple + ans_prefix_len += _prefix_len + handle_stack.append((new_node, new_key, new_value)) + else: + _prefix_len, ans_node = ans_tuple + ans_prefix_len += _prefix_len + + update_list.append(node) + + while len(update_list) != 0: + cur_node: TreeNode = update_list.pop() + cur_node.update_time() + if cur_node.is_leaf(): + self.evict_tree_set.add(cur_node) + + assert ans_node is not None + + return ans_prefix_len, ans_node + + def _insert_helper_no_recursion( + self, node: TreeNode, key: torch.Tensor, value: torch.Tensor + ) -> Union[Tuple[int, Optional[TreeNode]], Tuple[int, TreeNode, torch.Tensor, torch.Tensor]]: if node.is_leaf(): self.evict_tree_set.discard(node) - try: - first_key_id = key[0].item() - if first_key_id in node.children.keys(): - child: TreeNode = node.children[first_key_id] - prefix_len = match(key, child.token_id_key) - if prefix_len == len(key): + first_key_id = key[0].item() + if first_key_id in node.children.keys(): + child: TreeNode = node.children[first_key_id] + prefix_len = match(key, child.token_id_key) + if prefix_len == len(key): + if prefix_len == len(child.token_id_key): if child.is_leaf(): self.evict_tree_set.discard(child) child.update_time() if child.is_leaf(): self.evict_tree_set.add(child) - return prefix_len - - elif prefix_len < len(key) and prefix_len < len(child.token_id_key): + return prefix_len, child + elif prefix_len < len(child.token_id_key): if child.is_leaf(): self.evict_tree_set.discard(child) - 
key = key[prefix_len:] - value = value[prefix_len:] split_parent_node = child.split_node(prefix_len) - new_node = split_parent_node.add_and_return_new_child(key, value) - # update total token num - self.tree_total_tokens_num.arr[0] += len(new_node.token_mem_index_value) if split_parent_node.is_leaf(): self.evict_tree_set.add(split_parent_node) - if new_node.is_leaf(): - self.evict_tree_set.add(new_node) - if child.is_leaf(): self.evict_tree_set.add(child) - return prefix_len - elif prefix_len < len(key) and prefix_len == len(child.token_id_key): - return prefix_len + self._insert_helper(child, key[prefix_len:], value[prefix_len:]) + + return prefix_len, split_parent_node else: assert False, "can not run to here" - else: - new_node = node.add_and_return_new_child(key, value) + elif prefix_len < len(key) and prefix_len < len(child.token_id_key): + if child.is_leaf(): + self.evict_tree_set.discard(child) + + key = key[prefix_len:] + value = value[prefix_len:] + split_parent_node = child.split_node(prefix_len) + new_node = split_parent_node.add_and_return_new_child(key, value) # update total token num self.tree_total_tokens_num.arr[0] += len(new_node.token_mem_index_value) + + if split_parent_node.is_leaf(): + self.evict_tree_set.add(split_parent_node) if new_node.is_leaf(): self.evict_tree_set.add(new_node) - return 0 - finally: - node.update_time() - if node.is_leaf(): - self.evict_tree_set.add(node) + + if child.is_leaf(): + self.evict_tree_set.add(child) + return prefix_len, new_node + elif prefix_len < len(key) and prefix_len == len(child.token_id_key): + return (prefix_len, child, key[prefix_len:], value[prefix_len:]) + else: + assert False, "can not run to here" + + else: + new_node = node.add_and_return_new_child(key, value) + # update total token num + self.tree_total_tokens_num.arr[0] += len(new_node.token_mem_index_value) + if new_node.is_leaf(): + self.evict_tree_set.add(new_node) + return 0, new_node def match_prefix(self, key, update_refs=False): assert len(key) != 0 @@ -199,7 +244,39 @@ def match_prefix(self, key, update_refs=False): self.dec_node_ref_counter(self.root_node) return None, 0, None - def _match_prefix_helper(self, node: TreeNode, key, ans_value_list: list, update_refs=False) -> TreeNode: + def _match_prefix_helper( + self, node: TreeNode, key: torch.Tensor, ans_value_list: list, update_refs=False + ) -> TreeNode: + handle_stack = collections.deque() + update_list = collections.deque() + handle_stack.append((node, key)) + + ans_node = None + + while len(handle_stack) != 0: + node, key = handle_stack.popleft() + ans_tuple = self._match_prefix_helper_no_recursion( + node=node, key=key, ans_value_list=ans_value_list, update_refs=update_refs + ) + if isinstance(ans_tuple, tuple): + new_node, new_key = ans_tuple + handle_stack.append((new_node, new_key)) + else: + ans_node = ans_tuple + + update_list.append(node) + + while len(update_list) != 0: + cur_node: TreeNode = update_list.pop() + cur_node.update_time() + if cur_node.is_leaf(): + self.evict_tree_set.add(cur_node) + + return ans_node + + def _match_prefix_helper_no_recursion( + self, node: TreeNode, key: torch.Tensor, ans_value_list: list, update_refs=False + ) -> TreeNode: if node.is_leaf(): self.evict_tree_set.discard(node) @@ -209,44 +286,39 @@ def _match_prefix_helper(self, node: TreeNode, key, ans_value_list: list, update if node.ref_counter == 1: self.refed_tokens_num.arr[0] += len(node.token_mem_index_value) - try: - if len(key) == 0: - return node + if len(key) == 0: + return node - first_key_id = 
key[0].item() - if first_key_id not in node.children.keys(): - return node + first_key_id = key[0].item() + if first_key_id not in node.children.keys(): + return node + else: + child = node.children[first_key_id] + prefix_len = match(key, child.token_id_key) + if prefix_len == len(child.token_id_key): + ans_value_list.append(child.token_mem_index_value) + return (child, key[prefix_len:]) + elif prefix_len < len(child.token_id_key): + if child.is_leaf(): + self.evict_tree_set.discard(child) + + split_parent_node = child.split_node(prefix_len) + ans_value_list.append(split_parent_node.token_mem_index_value) + + if update_refs: + split_parent_node.ref_counter += 1 + # from 0 to 1 need update refs token num + if split_parent_node.ref_counter == 1: + self.refed_tokens_num.arr[0] += len(split_parent_node.token_mem_index_value) + + if child.is_leaf(): + self.evict_tree_set.add(child) + if split_parent_node.is_leaf(): + self.evict_tree_set.add(split_parent_node) + + return split_parent_node else: - child = node.children[first_key_id] - prefix_len = match(key, child.token_id_key) - if prefix_len == len(child.token_id_key): - ans_value_list.append(child.token_mem_index_value) - return self._match_prefix_helper(child, key[prefix_len:], ans_value_list, update_refs=update_refs) - elif prefix_len < len(child.token_id_key): - if child.is_leaf(): - self.evict_tree_set.discard(child) - - split_parent_node = child.split_node(prefix_len) - ans_value_list.append(split_parent_node.token_mem_index_value) - - if update_refs: - split_parent_node.ref_counter += 1 - # from 0 to 1 need update refs token num - if split_parent_node.ref_counter == 1: - self.refed_tokens_num.arr[0] += len(split_parent_node.token_mem_index_value) - - if child.is_leaf(): - self.evict_tree_set.add(child) - if split_parent_node.is_leaf(): - self.evict_tree_set.add(split_parent_node) - - return split_parent_node - else: - assert False, "error state" - finally: - node.update_time() - if node.is_leaf(): - self.evict_tree_set.add(node) + assert False, "error state" def evict(self, need_remove_tokens, evict_callback): if self.tree_total_tokens_num.arr[0] - self.refed_tokens_num.arr[0] < need_remove_tokens: @@ -313,6 +385,37 @@ def dec_node_ref_counter(self, node: TreeNode): self.evict_tree_set.add(old_node) return + def add_node_ref_counter(self, node: TreeNode): + if node is None: + return + # 如果减引用的是叶节点,需要先从 evict_tree_set 中移除 + old_node = node + if old_node.is_leaf(): + self.evict_tree_set.discard(old_node) + + while node is not None: + if node.ref_counter == 0: + self.refed_tokens_num.arr[0] += len(node.token_mem_index_value) + node.ref_counter += 1 + node = node.parent + + # 加回。 + if old_node.is_leaf(): + self.evict_tree_set.add(old_node) + return + + def get_mem_index_value_by_node(self, node: TreeNode) -> Optional[torch.Tensor]: + if node is None: + return None + + ans_list = [] + while node is not None: + ans_list.append(node.token_mem_index_value) + node = node.parent + + ans_list.reverse() + return torch.concat(ans_list, dim=0) + def get_refed_tokens_num(self): return self.refed_tokens_num.arr[0] diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index 38dd21155..3c8ca2399 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -23,6 +23,7 @@ ) from lightllm.server.core.objs import ShmReqManager, StartArgs from .dynamic_prompt.radix_cache import RadixCacheReadOnlyClient +from lightllm.server.multi_level_kv_cache.cpu_cache_client import CpuKvCacheClient from 
lightllm.server.core.objs.shm_objs_io_buffer import ShmObjsIOBuffer from lightllm.utils.log_utils import init_logger, log_time_ready from lightllm.server.router.token_load import TokenLoad @@ -38,7 +39,7 @@ class RouterManager: - def __init__(self, args: StartArgs, router_port, detokenization_port, metric_port): + def __init__(self, args: StartArgs): self.args = args self.model_weightdir = args.model_dir self.world_size = args.tp @@ -76,11 +77,11 @@ def __init__(self, args: StartArgs, router_port, detokenization_port, metric_por self.running_batch: Batch = None context = zmq.Context(2) - self.recv_from_httpserver = context.socket(zmq.PULL) - self.recv_from_httpserver.bind(f"{args.zmq_mode}127.0.0.1:{router_port}") + self.zmq_recv_socket = context.socket(zmq.PULL) + self.zmq_recv_socket.bind(f"{args.zmq_mode}127.0.0.1:{args.router_port}") self.send_to_detokenization = context.socket(zmq.PUSH) - self.send_to_detokenization.connect(f"{args.zmq_mode}127.0.0.1:{detokenization_port}") + self.send_to_detokenization.connect(f"{args.zmq_mode}127.0.0.1:{args.detokenization_port}") if self.is_multinode_tp: self.mulitnode_group = dist.init_process_group( @@ -90,7 +91,7 @@ def __init__(self, args: StartArgs, router_port, detokenization_port, metric_por rank=args.node_rank, ) - self.metric_client = MetricClient(metric_port) + self.metric_client = MetricClient(args.metric_port) self.is_pd_run_mode = self.args.run_mode in ["prefill", "decode", "nixl_prefill", "nixl_decode"] self.is_pd_decode_mode = self.args.run_mode in ["decode", "nixl_decode"] # p d 分离模式下,需要调度锁来同步调度端和推理端的一些数据操作 @@ -99,6 +100,12 @@ def __init__(self, args: StartArgs, router_port, detokenization_port, metric_por g_router_lock.obj = self.router_lock self.shm_reqs_io_buffer = ShmObjsIOBuffer() + + self.cpu_cache_client = ( + None + if not self.args.enable_cpu_cache + else CpuKvCacheClient(only_create_meta_data=True, init_shm_data=False) + ) return async def wait_to_model_ready(self): @@ -508,7 +515,7 @@ async def _recv_new_reqs_and_schedule(self): try: # 一次最多从 zmq 中取 recv_max_count 个请求,防止 zmq 队列中请求数量过多导致阻塞了主循环。 for _ in range(self.recv_max_count): - recv_req: GroupReqIndexes = self.recv_from_httpserver.recv_pyobj(zmq.NOBLOCK) + recv_req: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) if isinstance(recv_req, GroupReqIndexes): self._add_req(recv_req) else: @@ -532,7 +539,7 @@ def clean_up(self): return -def start_router_process(args, router_port, detokenization_port, metric_port, pipe_writer): +def start_router_process(args, pipe_writer): # 注册 graceful 退出的处理 graceful_registry(inspect.currentframe().f_code.co_name) setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::router_server") @@ -547,10 +554,7 @@ def handle_exception(loop, context): try: router = RouterManager( - args, - router_port=router_port, - detokenization_port=detokenization_port, - metric_port=metric_port, + args=args, ) loop.run_until_complete(router.wait_to_model_ready()) diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 643c317bd..3fe3f5136 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -1,6 +1,4 @@ -import os -import copy -import time +import enum import torch import torch.distributed as dist import numpy as np @@ -30,22 +28,25 @@ class InferenceContext: radix_cache: RadixCache = None shm_req_manager: ShmReqManager = None # 共享内存请求对象管理 requests_mapping: Dict[int, "InferReq"] = None - group_mapping = None 
# 只有进行多输出模式下才有真的使用 infer_req_ids = None vocab_size = None overlap_stream: torch.cuda.Stream = None # 一些情况下推理进程进行异步折叠操作的异步流对象。 + cpu_kv_cache_stream: torch.cuda.Stream = None # 用 cpu kv cache 操作的 stream def register( - self, req_manager: ReqManager, radix_cache: RadixCache, shm_req_manager: ShmReqManager, vocab_size: int + self, backend, req_manager: ReqManager, radix_cache: RadixCache, shm_req_manager: ShmReqManager, vocab_size: int ): + self.args = get_env_start_args() + from lightllm.server.router.model_infer.mode_backend.base_backend import ModeBackend + + self.backend: ModeBackend = backend self.req_manager = req_manager self.req_sampling_manager = self.req_manager.req_sampling_params_manager self.radix_cache = radix_cache self.shm_req_manager = shm_req_manager self.requests_mapping = {} - self.group_mapping: Dict[int, InferReqGroup] = {} self.infer_req_ids = [] self.vocab_size = vocab_size @@ -56,6 +57,11 @@ def get_overlap_stream(self) -> torch.cuda.Stream: self.overlap_stream = torch.cuda.Stream() return self.overlap_stream + def get_cpu_kv_cache_stream(self) -> torch.cuda.Stream: + if self.cpu_kv_cache_stream is None: + self.cpu_kv_cache_stream = torch.cuda.Stream() + return self.cpu_kv_cache_stream + def add_reqs(self, requests: List[Tuple[int, int, Any, int]], init_prefix_cache: bool = True) -> List["InferReq"]: req_objs = [] request_ids = [] @@ -76,46 +82,42 @@ def add_reqs(self, requests: List[Tuple[int, int, Any, int]], init_prefix_cache: self.infer_req_ids.extend(request_ids) - # 多输出模式下需要将请求添加到各自的组对象 InferReqGroup 中 + # diverse mode 下,建立一组请求间的主从关系 if get_env_start_args().diverse_mode: + group_reqs: Dict[int, InferReq] = collections.defaultdict(lambda: [None, list()]) for r_id in request_ids: req: InferReq = g_infer_context.requests_mapping[r_id] group_req_id = req.shm_req.group_req_id - if group_req_id not in g_infer_context.group_mapping: - g_infer_context.group_mapping[group_req_id] = InferReqGroup(group_req_id=group_req_id) - g_infer_context.group_mapping[group_req_id].add_req(r_id) + if req.req_id == group_req_id: + group_reqs[group_req_id][0] = req + else: + group_reqs[group_req_id][1].append(req) + + for group_req_id, (master_req, slave_reqs) in group_reqs.items(): + master_req: InferReq = master_req + master_req.slave_reqs.extend(slave_reqs) + for slave_req in slave_reqs: + slave_req: InferReq = slave_req + slave_req.related_master_req = master_req return req_objs - def free_a_req_mem(self, free_token_index: List, req: "InferReq", is_group_finished: bool): + def free_a_req_mem(self, free_token_index: List, req: "InferReq"): if self.radix_cache is None: - if is_group_finished: - free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][0 : req.cur_kv_len]) - else: - free_token_index.append( - self.req_manager.req_to_token_indexs[req.req_idx][req.shm_req.input_len : req.cur_kv_len] - ) + free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][0 : req.cur_kv_len]) else: input_token_ids = req.get_input_token_ids() key = torch.tensor(input_token_ids[0 : req.cur_kv_len], dtype=torch.int64, device="cpu") # .cpu() 是 流内阻塞操作 value = self.req_manager.req_to_token_indexs[req.req_idx][: req.cur_kv_len].detach().cpu() - if is_group_finished: - prefix_len = self.radix_cache.insert(key, value) - old_prefix_len = 0 if req.shared_kv_node is None else req.shared_kv_node.node_prefix_total_len - free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][old_prefix_len:prefix_len]) - if req.shared_kv_node is not None: - assert 
req.shared_kv_node.node_prefix_total_len <= prefix_len - self.radix_cache.dec_node_ref_counter(req.shared_kv_node) - req.shared_kv_node = None - else: - free_token_index.append( - self.req_manager.req_to_token_indexs[req.req_idx][req.shm_req.input_len : req.cur_kv_len] - ) - if req.shared_kv_node is not None: - self.radix_cache.dec_node_ref_counter(req.shared_kv_node) - req.shared_kv_node = None + prefix_len, _ = self.radix_cache.insert(key, value) + old_prefix_len = 0 if req.shared_kv_node is None else req.shared_kv_node.node_prefix_total_len + free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][old_prefix_len:prefix_len]) + if req.shared_kv_node is not None: + assert req.shared_kv_node.node_prefix_total_len <= prefix_len + self.radix_cache.dec_node_ref_counter(req.shared_kv_node) + req.shared_kv_node = None def _save_promptcache_kvbuffer(self): """ @@ -140,14 +142,10 @@ def _filter(self, finished_request_ids: List[int]): free_token_index = [] for request_id in finished_request_ids: req: InferReq = self.requests_mapping.pop(request_id) - group_req_id = convert_sub_id_to_group_id(req.shm_req.request_id) - if group_req_id in self.group_mapping: - is_group_finished = self.group_mapping[group_req_id].remove_req(req.shm_req.request_id) - if is_group_finished: - del self.group_mapping[group_req_id] - self.free_a_req_mem(free_token_index, req, is_group_finished) - else: - self.free_a_req_mem(free_token_index, req, True) + if self.args.diverse_mode: + req.clear_master_slave_state() + self.free_a_req_mem(free_token_index, req) + free_req_index.append(req.req_idx) # logger.info(f"infer release req id {req.shm_req.request_id}") req.shm_req.shm_infer_released = True @@ -184,8 +182,10 @@ def pause_reqs(self, pause_reqs: List["InferReq"], is_master_in_dp: bool): free_token_index = [] for req in pause_reqs: - # 不支持多输出的情况的暂停, 不能支持 diverse 输出模式。 - self.free_a_req_mem(free_token_index, req, is_group_finished=True) + if self.args.diverse_mode: + # 发生暂停的时候,需要清除 diverse 模式下的主从关系 + req.clear_master_slave_state() + self.free_a_req_mem(free_token_index, req) req.cur_kv_len = 0 req.shm_req.shm_cur_kv_len = req.cur_kv_len assert req.wait_pause is True @@ -290,6 +290,20 @@ def has_constraint_setting(self) -> bool: class InferReq: + class _CpuCacheTaskStatus(enum.Enum): + NOT_STARTED = 0 + RUNNING = 1 + FINISHED = 2 + + def is_not_started(self): + return self == self.NOT_STARTED + + def is_running(self): + return self == self.RUNNING + + def is_finished(self): + return self == self.FINISHED + def __init__( self, req_id: int, @@ -315,6 +329,10 @@ def __init__( self.need_out_token_id_statistics = True self.out_token_id_count: Dict[int, int] = None + # diverse mode 下,用于标记请求组之间的依赖关系 + self.slave_reqs: List[InferReq] = [] + self.related_master_req: InferReq = None + # nixl pd 分离模式使用的变量, 普通模式下这些变量没有具体用途 self.nixl_trans_kv_start_index: int = 0 self.nixl_pd_task_num: int = 0 @@ -322,6 +340,10 @@ def __init__( self.nixl_pd_task_failed_num: int = 0 self.nixl_trans_device_id: int = -1 + # 在开启 enable_cpu_cache 的情况下,当请求结束后,会将请求的 kv cache + # 卸载到 cpu cache 中,该标志变量用于标记请求的卸载任务的状态 + self.cpu_cache_task_status: "InferReq._CpuCacheTaskStatus" = InferReq._CpuCacheTaskStatus.NOT_STARTED + # mtp_step 用来记录一个请求 draft模型每步需要生成的token数量 # 正常模式下,这个值为0,在 mtp 模式下,这个值为 draft 模型每步需要生成的token数量 self.mtp_step: int = get_env_start_args().mtp_step @@ -381,6 +403,37 @@ def _match_radix_cache(self): self.shm_req.shm_cur_kv_len = self.cur_kv_len return + def is_master_req(self): + """ + diverse 模式下,判断当前请求是否为独立主请求,其进行prefill后,将 + kv 通过 
radix cache 共享给其他 slave 请求, 共享后 slave 请求也 + 会升级为 master 请求,具有独立推理,暂停的特性。 + """ + return self.related_master_req is None + + def is_slave_req(self): + return self.related_master_req is not None + + def clear_master_slave_state(self): + if self.is_slave_req(): + self.remove_master_req() + elif self.is_master_req(): + # 数组需要 copy 后遍历。 + for slave_req in self.slave_reqs.copy(): + slave_req.remove_master_req() + + def remove_master_req(self): + """ + 一个处于 slave 状态的请求,解除与 master 请求的依赖关系后,自己会升级为 + master_req 的状态,具有独立推理,暂停的特性。 + """ + master_req = self.related_master_req + if master_req is not None: + master_req.slave_reqs.remove(self) + self.related_master_req = None + else: + logger.warning(f"try to remove master req, but related_master_req is None, req id {self.req_id}") + def get_output_len(self): return self.cur_output_len @@ -456,52 +509,6 @@ def _mtp_decode_need_token_num(self) -> int: return (1 + self.mtp_step) * 2 -class InferReqGroup: - def __init__( - self, - group_req_id: int, - ) -> None: - self.group_req_id = group_req_id - self.req_ids_group = [] - - def get_req(self, index): - return g_infer_context.requests_mapping[self.req_ids_group[index]] - - def get_all_reqs(self): - return [g_infer_context.requests_mapping[self.req_ids_group[i]] for i in range(len(self.req_ids_group))] - - def add_req(self, req_id): - self.req_ids_group.append(req_id) - - def remove_req(self, req_id): - assert req_id in self.req_ids_group - self.req_ids_group.remove(req_id) - return len(self.req_ids_group) == 0 - - def best_of(self): - return len(self.req_ids_group) - - def diverse_copy(self, req_manager, is_prefill): - # record previous status - prev_req = g_infer_context.requests_mapping[convert_sub_id_to_group_id(self.req_ids_group[0])] - if prev_req.shared_kv_node is not None: - prefix_len = prev_req.shared_kv_node.node_prefix_total_len - else: - prefix_len = 0 - prefix_len = max(prefix_len, prev_req.cur_kv_len) - pre_input_len = prev_req.get_chuncked_input_token_len() - cache_token_id = req_manager.req_to_token_indexs[prev_req.req_idx][prefix_len:pre_input_len] - # update the InferReq status and mem_manager status for cache sharing - for req_id in self.req_ids_group[:]: - if req_id == convert_sub_id_to_group_id(req_id): - continue - req = g_infer_context.requests_mapping[req_id] - req.finish_status.set_status(FinishStatus.NO_FINISH) - input_len = req.get_chuncked_input_token_len() - req_manager.req_to_token_indexs[req.req_idx][prefix_len:input_len] = cache_token_id - assert input_len == pre_input_len - - class InferReqUpdatePack: """ 用于延迟InferReq的请求更新,主要是为了方便更高效的overlap机制实现。解耦 diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 07dfc19fa..b6cb4d21f 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -34,6 +34,7 @@ from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token from lightllm.server.pd_io_struct import NIXLChunckedTransTaskRet +from .multi_level_kv_cache import MultiLevelKvCacheModule class ModeBackend: @@ -122,6 +123,11 @@ def init_model(self, kvargs): # 所以做一次barrier等待 dist.barrier() + wait_events = [] + if self.args.enable_cpu_cache: + self.multi_level_cache_module = MultiLevelKvCacheModule(self) + wait_events.append(self.multi_level_cache_module) + model_cfg, _ = 
PretrainedConfig.get_config_dict(self.weight_dir) model_kvargs = { @@ -143,6 +149,7 @@ def init_model(self, kvargs): "quant_type": kvargs.get("quant_type", None), "quant_cfg": kvargs.get("quant_cfg", None), "run_mode": self.run_mode, + "wait_events": wait_events, } self.model, self.is_multimodal = get_model(model_cfg, model_kvargs) self.model: TpPartBaseModel = self.model # for easy typing @@ -164,6 +171,7 @@ def init_model(self, kvargs): self.logger.info(f"loaded model class {self.model.__class__}") g_infer_context.register( + backend=self, req_manager=self.model.req_manager, radix_cache=self.radix_cache, shm_req_manager=self.shm_req_manager, @@ -359,7 +367,9 @@ def _read_reqs_buffer_and_init_reqs(self): else: assert False, f"error type {type(obj)}" if init_reqs: - self._init_reqs(reqs=init_reqs) + req_ids = self._init_reqs(reqs=init_reqs) + if self.args.enable_cpu_cache and req_ids: + self._load_cpu_cache_to_reqs(req_ids=req_ids) return def _read_nixl_trans_io_buffer_and_update_req_status(self): @@ -414,6 +424,13 @@ def _init_reqs(self, reqs: List[Tuple]): req_ids = [e[0] for e in reqs] return req_ids + def _load_cpu_cache_to_reqs(self, req_ids): + req_objs: List[InferReq] = [g_infer_context.requests_mapping[req_id] for req_id in req_ids] + g_infer_state_lock.acquire() + self.multi_level_cache_module.load_cpu_cache_to_reqs(reqs=req_objs) + g_infer_state_lock.release() + return + def _filter_not_ready_reqs(self, req_ids: List[int]) -> List[InferReq]: """ 将错误请求从 req_ids 中过滤出来, 然后让 _get_classed_reqs 进行处理。 该函数 @@ -448,6 +465,8 @@ def _get_classed_reqs( 4. prefill_reqs 需要进行prefill操作的请求 5. decode_reqs 需要进行decode操作的请求 """ + if self.args.enable_cpu_cache and len(g_infer_context.infer_req_ids) > 0: + self.multi_level_cache_module.update_cpu_cache_task_states() if req_ids is None: req_ids = g_infer_context.infer_req_ids @@ -517,6 +536,12 @@ def _get_classed_reqs( req_obj.wait_pause = True wait_pause_count += 1 else: + # 在 diverse mode 模式下,prefill 只会使用 master 状态的请求,slave 请求依靠后续 + # 的推理代码中将master请求的状态复制到slave请求中去, 所以这里 slave 状态的请求,不 + # 放入到 prefill reqs 队列中,在其他模式下,所有请求都是 master状态,所以也不受影响 + if req_obj.is_slave_req(): + continue + token_num = req_obj.prefill_need_token_num(is_chuncked_prefill=not self.disable_chunked_prefill) if prefill_tokens + token_num > self.batch_max_tokens: continue @@ -532,8 +557,15 @@ def _get_classed_reqs( g_infer_state_lock.release() self._pre_handle_finished_reqs(finished_reqs=finished_reqs) - g_infer_context.filter_reqs(finished_reqs=finished_reqs) + # 如果使能了 cpu cache 功能,对于已经完成的请求,进行 gpu kv 卸载到 cpu cache的操作。 + if self.args.enable_cpu_cache: + true_finished_reqs = self.multi_level_cache_module.offload_finished_reqs_to_cpu_cache( + finished_reqs=finished_reqs + ) + else: + true_finished_reqs = finished_reqs + g_infer_context.filter_reqs(finished_reqs=true_finished_reqs) g_infer_context.pause_reqs(wait_pause_reqs, is_master_in_dp=self.is_master_in_dp) if recover_paused: diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_infer_rpyc.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_infer_rpyc.py index 1ff167960..1e5cccb1f 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_infer_rpyc.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_infer_rpyc.py @@ -131,7 +131,7 @@ def _put_kv_received_to_radix_cache(self, group_req_id: int): radix_cache = 
self.backend.radix_cache key = torch.tensor(move_task.input_tokens, dtype=torch.int64, device="cpu") value = torch.tensor(fused_token_indexes + move_task.decode_token_indexes, dtype=torch.int64, device="cpu") - prefix_len = radix_cache.insert(key, value) + prefix_len, _ = radix_cache.insert(key, value) assert len(fused_token_indexes) <= prefix_len self.backend.model.mem_manager.free(value[len(fused_token_indexes) : prefix_len]) self.backend.radix_cache.dec_node_ref_counter(tree_node) diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl.py index 8addf7b18..a2ff08bd2 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl.py @@ -74,17 +74,23 @@ def _prefill_req_frozen_tokens_and_put_to_kvmove_taskqueue(self, finished_reqs: key = req.get_input_token_ids()[0 : req.cur_kv_len] key = torch.tensor(key, dtype=torch.int64, device="cpu") value = self.model.req_manager.req_to_token_indexs[req.req_idx][: req.cur_kv_len].detach().cpu() - prefix_len = self.radix_cache.insert(key, value) + prefix_len, new_shared_kv_node = self.radix_cache.insert(key, value) old_prefix_len = 0 if req.shared_kv_node is None else req.shared_kv_node.node_prefix_total_len self.model.mem_manager.free( self.model.req_manager.req_to_token_indexs[req.req_idx][old_prefix_len:prefix_len] ) - if req.shared_kv_node is not None: - self.radix_cache.dec_node_ref_counter(req.shared_kv_node) - req.shared_kv_node = None + # 将原有共享节点替换为新共享节点,新共享节点对应的长度为当前的cur_kv_len - req.cur_kv_len = 0 - req.shm_req.shm_cur_kv_len = 0 + self.radix_cache.dec_node_ref_counter(req.shared_kv_node) + self.radix_cache.add_node_ref_counter(new_shared_kv_node) + req.shared_kv_node = new_shared_kv_node + + _kv_len = req.cur_kv_len + _value = self.radix_cache.get_mem_index_value_by_node(new_shared_kv_node) + assert len(_value) == _kv_len + self.model.req_manager.req_to_token_indexs[req.req_idx][0:_kv_len] = _value + + assert new_shared_kv_node.node_prefix_total_len == req.cur_kv_len if req.shm_req.sample_params.move_kv_to_decode_node.exists: # 注意兼容纯tp 和 tp dp 混合模式的逻辑 diff --git a/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py b/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py index 85bb55191..c973c4e9c 100644 --- a/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py @@ -3,10 +3,9 @@ from lightllm.server.router.model_infer.infer_batch import ( g_infer_context, InferReq, - InferReqGroup, + InferReqUpdatePack, ) from typing import List, Tuple -from lightllm.server.req_id_generator import convert_sub_id_to_group_id from lightllm.server.router.model_infer.mode_backend.pre import ( prepare_prefill_inputs, ) @@ -15,6 +14,7 @@ from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token from lightllm.server.router.model_infer.pin_mem_manager import g_pin_mem_manager from ..chunked_prefill.impl import ChunkedPrefillBackend +from lightllm.common.basemodel.infer_lock import g_infer_state_lock class DiversehBackend(ChunkedPrefillBackend): @@ -23,32 +23,9 @@ def __init__(self) -> None: self.prefill = self.beam_prefill self.classed_req_strict_prefill = True - def 
diverse_copy(self, groups: List[InferReqGroup]) -> Tuple[List[int], List[InferReq]]: - batch_idx = [] - run_reqs = [] - for i in range(len(groups)): - req_group = groups[i] - best_of = req_group.best_of() - if best_of > 1: - req_group.diverse_copy(g_infer_context.req_manager, is_prefill=True) - batch_idx.extend([i for _ in range(best_of)]) - else: - batch_idx.append(i) - run_reqs.extend(req_group.get_all_reqs()) - return batch_idx, run_reqs - def beam_prefill(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq]): # 第一阶段 - group_reqs = [ - g_infer_context.requests_mapping[req.req_id] - for req in prefill_reqs - if convert_sub_id_to_group_id(req.req_id) == req.req_id - ] - groups = [ - g_infer_context.group_mapping[req.req_id] - for req in prefill_reqs - if convert_sub_id_to_group_id(req.req_id) == req.req_id - ] + group_reqs = [g_infer_context.requests_mapping[req.req_id] for req in prefill_reqs if req.is_master_req()] model_input, group_run_reqs = prepare_prefill_inputs( group_reqs, is_chuncked_mode=not self.disable_chunked_prefill, is_multimodal=self.is_multimodal @@ -59,7 +36,9 @@ def beam_prefill(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq model_output = self.model.forward(model_input) logits = model_output.logits - batch_idx, run_reqs = self.diverse_copy(groups) + batch_idx, run_reqs = self._diverse_copy( + master_reqs=group_reqs, b_prefill_has_out=model_input.b_prefill_has_output_cpu + ) b_req_idx = [req.req_idx for req in run_reqs] b_has_out = [model_input.b_prefill_has_output_cpu[i] for i in batch_idx] @@ -95,8 +74,8 @@ def beam_prefill(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq # 第二阶段 event_pack.notify_post_handle_and_wait_pre_post_handle() - update_packs = self._pre_post_handle(run_reqs, is_chuncked_mode=not self.disable_chunked_prefill) - + model_output.prefill_mem_indexes_ready_event.synchronize() + update_packs = self._diverse_pre_post_handle(run_reqs, is_chuncked_mode=not self.disable_chunked_prefill) # 第三阶段 event_pack.notify_forward_and_wait_post_handle() sync_event.synchronize() @@ -110,3 +89,122 @@ def beam_prefill(self, event_pack: OverlapEventPack, prefill_reqs: List[InferReq # 第四阶段 event_pack.notify_pre_post_handle() return + + def _diverse_copy( + self, master_reqs: List[InferReq], b_prefill_has_out: List[bool] + ) -> Tuple[List[int], List[InferReq]]: + batch_idx = [] + run_reqs = [] + for i in range(len(master_reqs)): + master_req = master_reqs[i] + slave_reqs = master_req.slave_reqs + slave_num = len(slave_reqs) + batch_idx.append(i) + run_reqs.append(master_req) + + if slave_num > 0 and b_prefill_has_out[i]: + batch_idx.extend([i for _ in range(slave_num)]) + run_reqs.extend(slave_reqs) + + return batch_idx, run_reqs + + # 一些可以复用的通用功能函数 + + def _diverse_pre_post_handle(self, run_reqs: List[InferReq], is_chuncked_mode: bool) -> List[InferReqUpdatePack]: + update_func_objs: List[InferReqUpdatePack] = [] + # 通用状态预先填充 + is_master_in_dp = self.is_master_in_dp + pre_master_req_pack = None + for req_obj in run_reqs: + req_obj: InferReq = req_obj + if req_obj.is_master_req(): + if is_chuncked_mode: + new_kv_len = req_obj.get_chuncked_input_token_len() + else: + new_kv_len = req_obj.get_cur_total_len() + req_obj.cur_kv_len = new_kv_len + if is_master_in_dp: + req_obj.shm_req.shm_cur_kv_len = req_obj.cur_kv_len + + # 对于没有到达需要输出 token 阶段的请求,直接略过, 说明还 + # 处于chuncked prefill kv 填充的阶段。 + if req_obj.cur_kv_len < req_obj.get_cur_total_len(): + pack = InferReqUpdatePack(req_obj=req_obj, output_len=0) + update_func_objs.append(pack) 
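+ # remember this master's pack: slave requests that follow a master in run_reqs
+ # copy output_len from the most recently recorded master pack.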
+ pre_master_req_pack = pack + # TODO 如果 diverse mode 需要支持 nixl pd 分离,则应该每个分块prefill后都进行相关的复制, + # 暂时不支持 diverse mode 和 pd 模式的混合 + continue + + # 将生成的下一个token的信息写入到管理对象中。 + req_obj.cur_output_len += 1 + pack = InferReqUpdatePack(req_obj=req_obj, output_len=req_obj.cur_output_len) + update_func_objs.append(pack) + pre_master_req_pack = pack + if req_obj.slave_reqs: + # 存在 slave reqs 的 master req 需要将自己的 kv 信息写入到 radix cache 中 + # 方便 slave req 进行 kv 的复用 + self._master_req_to_radix_cache(master_req=req_obj) + else: + # slave req 直接复用 master req 的更新包。 + assert pre_master_req_pack is not None + assert pre_master_req_pack.req_obj.shm_req.group_req_id == req_obj.shm_req.group_req_id + self._copy_master_req_to_slave_req(slave_req=req_obj) + # 在拷贝后,请求独立了,与 master_req 的关系解除 + req_obj.remove_master_req() + pack = InferReqUpdatePack(req_obj=req_obj, output_len=pre_master_req_pack.output_len) + update_func_objs.append(pack) + + torch.cuda.current_stream().synchronize() + return update_func_objs + + def _master_req_to_radix_cache(self, master_req: InferReq): + g_infer_state_lock.acquire() + key = master_req.get_input_token_ids()[0 : master_req.cur_kv_len] + key = torch.tensor(key, dtype=torch.int64, device="cpu") + value = self.model.req_manager.req_to_token_indexs[master_req.req_idx][: master_req.cur_kv_len].detach().cpu() + prefix_len, new_shared_kv_node = self.radix_cache.insert(key, value) + old_prefix_len = 0 if master_req.shared_kv_node is None else master_req.shared_kv_node.node_prefix_total_len + assert old_prefix_len <= master_req.cur_kv_len + self.model.mem_manager.free( + self.model.req_manager.req_to_token_indexs[master_req.req_idx][old_prefix_len:prefix_len] + ) + + # 将原有共享节点替换为新共享节点,新共享节点对应的长度为当前的cur_kv_len + self.radix_cache.dec_node_ref_counter(master_req.shared_kv_node) + self.radix_cache.add_node_ref_counter(new_shared_kv_node) + master_req.shared_kv_node = new_shared_kv_node + assert ( + new_shared_kv_node.node_prefix_total_len == master_req.cur_kv_len + ), f"shared len: {new_shared_kv_node.node_prefix_total_len} cur_kv_len {master_req.cur_kv_len}" + + share_node, kv_len, value = self.radix_cache.match_prefix(key, update_refs=False) + assert share_node == new_shared_kv_node and kv_len == master_req.cur_kv_len + self.model.req_manager.req_to_token_indexs[master_req.req_idx][0 : master_req.cur_kv_len] = value + g_infer_state_lock.release() + return + + def _copy_master_req_to_slave_req(self, slave_req: InferReq): + g_infer_state_lock.acquire() + master_req = slave_req.related_master_req + assert master_req is not None + + self.radix_cache.dec_node_ref_counter(slave_req.shared_kv_node) + self.radix_cache.add_node_ref_counter(master_req.shared_kv_node) + + kv_len = master_req.cur_kv_len + + self.model.req_manager.req_to_token_indexs[slave_req.req_idx][ + 0:kv_len + ] = self.model.req_manager.req_to_token_indexs[master_req.req_idx][0:kv_len] + # torch.cuda.current_stream().synchronize() + slave_req.shared_kv_node = master_req.shared_kv_node + slave_req.cur_kv_len = kv_len + slave_req.cur_output_len = master_req.cur_output_len + if self.is_master_in_dp: + slave_req.shm_req.shm_cur_kv_len = slave_req.cur_kv_len + + assert kv_len <= slave_req.shm_req.input_len + + g_infer_state_lock.release() + return diff --git a/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py b/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py new file mode 100644 index 000000000..9d0d681dc --- /dev/null +++ 
b/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py @@ -0,0 +1,270 @@ +import threading +import torch.distributed as dist +import torch +import dataclasses +from typing import Optional, List, Deque +from collections import deque +from lightllm.server.multi_level_kv_cache.cpu_cache_client import CpuKvCacheClient +from lightllm.utils.envs_utils import get_env_start_args, disable_cpu_kvcache_sync +from ..infer_batch import InferReq +from lightllm.utils.dist_utils import create_new_group_for_current_dp +from lightllm.common.basemodel.triton_kernel.kv_cache_offload import offload_gpu_kv_to_cpu, load_cpu_kv_to_gpu +from lightllm.server.router.model_infer.infer_batch import g_infer_context +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class MultiLevelKvCacheModule(object): + def __init__(self, backend): + self.args = get_env_start_args() + from .base_backend import ModeBackend + + self.backend: ModeBackend = backend + self.gloo_group = create_new_group_for_current_dp("gloo") + self.filter_group = create_new_group_for_current_dp("gloo") + self.init_sync_group = create_new_group_for_current_dp("nccl") + dist.barrier(group=self.init_sync_group) + + self.cpu_cache_handle_queue: Deque[TransTask] = deque() + self.cpu_cache_client = CpuKvCacheClient(only_create_meta_data=False, init_shm_data=False) + + # 一些算子模式需要同步计算和 cpu cache 的 load 和 offload 操作 + self.need_sync_compute_stream: bool = self.args.enable_fa3 and not disable_cpu_kvcache_sync() + + def wait(self): + """ + 等待 cpu cache 相关页面注册完成 + """ + attach_shm_handle = self.cpu_cache_client.attach_shm_handle + if attach_shm_handle is not None: + attach_shm_handle.wait() + + def load_cpu_cache_to_reqs(self, reqs: List[InferReq]): + idle_token_num = g_infer_context.get_can_alloc_token_num() + token_page_size = self.args.cpu_cache_token_page_size + all_page_list = [] + is_master_in_dp = self.backend.is_master_in_dp + for req in reqs: + page_list = req.shm_req.cpu_cache_match_page_indexes.get_all() + match_tokens = len(page_list) * token_page_size + # 更新命中的 cpu kv cache 长度. 
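+ # only the master rank of the dp group writes the matched length back to the
+ # shared-memory request object; the other ranks skip the shm update.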
+ if is_master_in_dp: + req.shm_req.cpu_prompt_cache_len = match_tokens + + need_token_num = match_tokens - req.cur_kv_len + # 多匹配了一定数量的token同时请求长度大于一定的长度,才进行复制操作,不然操作效率不高,代价过高 + if need_token_num >= 256 and req.shm_req.input_len >= 512: + if need_token_num <= idle_token_num: + if self.backend.radix_cache is not None: + g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(need_token_num=need_token_num) + + # 计算需要加载的页面(只加载未匹配的部分) + cur_kv_pages = req.cur_kv_len // token_page_size + need_pages = page_list[cur_kv_pages:] # 只取需要的页面 + + mem_indexes = g_infer_context.req_manager.mem_manager.alloc(need_size=need_token_num) + + if self.need_sync_compute_stream: + # TODO fa3 现在必须使用同步模式, 未来需要移除 + g_infer_context.get_overlap_stream().synchronize() + + # TODO 更有效的分配策略。 + grid_num = 16 if self.need_sync_compute_stream or (not self.args.enable_fa3) else 1 + + # 将 cpu page 的内容拷贝到 gpu 页面中 + load_cpu_kv_to_gpu( + gpu_mem_indexes=mem_indexes.cuda(non_blocking=True), + gpu_kv_cache=self.backend.model.mem_manager.kv_buffer, + cpu_kv_cache=self.cpu_cache_client.cpu_kv_cache_tensor, + page_indexes=torch.tensor(need_pages, dtype=torch.int32, device="cpu").cuda(non_blocking=True), + tp_index=self.backend.rank_in_dp, + tp_world_size=self.backend.dp_world_size, + grid_num=grid_num, + ) + + torch.cuda.current_stream().synchronize() + + idle_token_num -= need_token_num + g_infer_context.req_manager.req_to_token_indexs[ + req.req_idx, req.cur_kv_len : (req.cur_kv_len + need_token_num) + ] = mem_indexes + req.cur_kv_len = req.cur_kv_len + need_token_num + if self.backend.is_master_in_dp: + req.shm_req.shm_cur_kv_len = req.cur_kv_len + + all_page_list.extend(page_list) + + dist.barrier(group=self.init_sync_group) + + if self.backend.is_master_in_dp: + self.cpu_cache_client.lock.acquire_sleep1ms() + self.cpu_cache_client.deref_pages(page_list=all_page_list) + self.cpu_cache_client.lock.release() + return + + def offload_finished_reqs_to_cpu_cache(self, finished_reqs: List[InferReq]) -> List[InferReq]: + """ + 将满足cpu kv cache 卸载条件的请求进行处理, 并返回真的满足退出条件的请求list。 + """ + # 如果开启了cpu cache,将达到finished状态的请求开启将gpu kv cache 卸载到 cpu cache中的操作。 + # 当 kv cache 卸载完成后,才会进行请求的真实退出操作。 + true_finished_reqs = [] + cpu_stream = g_infer_context.get_cpu_kv_cache_stream() + for req in finished_reqs: + # 只有 group_req_id 和 request_id 相同的请求才会被卸载到 cpu cache 中。 + # 这个限制是为了兼容 diverse 模式下的请求处理, 只有主请求才 offload kv 到 cpu + # cache 中 + if req.shm_req.group_req_id != req.shm_req.request_id: + true_finished_reqs.append(req) + continue + + # 过滤不适合进行 kv 卸载到 cpu cache 的请求。 + if ( + req.cur_kv_len < self.args.cpu_cache_token_page_size + or req.shm_req.input_len <= self.args.cpu_cache_token_page_size + ): + true_finished_reqs.append(req) + continue + + # 如果请求已经完成了 cpu cache 的任务,则满足了退出条件 + if req.cpu_cache_task_status.is_finished(): + true_finished_reqs.append(req) + continue + + # 如果请求已经发起过卸载任务且正在卸载过程中,则在当前轮不进行处理 + if req.cpu_cache_task_status.is_running(): + continue + + assert req.cpu_cache_task_status.is_not_started() + + if self.need_sync_compute_stream: + # TODO fa3 现在必须使用同步模式, 未来需要移除, 必须等待 overlap stream 上的计算任务完成,不然会崩溃 + g_infer_context.get_overlap_stream().synchronize() + + # 发起将请求的 kv cache 卸载到 cpu cache 中的任务 + trans_task = self._start_kv_cache_offload_task(req=req, cpu_kv_cache_stream=cpu_stream) + + # 根据是否成功创建了卸载任务,决定是否将请求加入到处理队列中 + if trans_task is not None: + self.cpu_cache_handle_queue.append(trans_task) + else: + true_finished_reqs.append(req) + + if self.need_sync_compute_stream: + # TODO fa3 现在必须使用同步模式, 未来需要移除 + 
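+ # block on the cpu kv cache stream so every offload copy launched above has
+ # completed before this function returns.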
cpu_stream.synchronize() + + return true_finished_reqs + + def _start_kv_cache_offload_task( + self, req: InferReq, cpu_kv_cache_stream: torch.cuda.Stream + ) -> Optional["TransTask"]: + with torch.cuda.stream(cpu_kv_cache_stream): + if self.backend.is_master_in_dp: + # 综合考虑后只对prompt做缓存管理,不包含decode内容,这里与radix cache不一致 + token_hash_list = req.shm_req.token_hash_list.get_all() + block_size = req.cur_kv_len // self.args.cpu_cache_token_page_size + move_block_size = min(block_size, len(token_hash_list)) + + if move_block_size == 0: + dist.broadcast_object_list([0], group=self.gloo_group, group_src=0) + req.cpu_cache_task_status = InferReq._CpuCacheTaskStatus.FINISHED + return None + + try: + self.cpu_cache_client.lock.acquire_sleep1ms() + page_list, ready_list = self.cpu_cache_client.allocate_pages( + token_hash_list[:move_block_size], + disk_offload_enable=self.args.enable_disk_cache, + ) + finally: + self.cpu_cache_client.lock.release() + + item_size = len(page_list) + if item_size == 0: + dist.broadcast_object_list([0], group=self.gloo_group, group_src=0) + req.cpu_cache_task_status = InferReq._CpuCacheTaskStatus.FINISHED + return None + + broadcast_data = {"item_size": item_size, "page_list": page_list, "ready_list": ready_list} + dist.broadcast_object_list([broadcast_data], group=self.gloo_group, group_src=0) + else: + recv_list = [None] + dist.broadcast_object_list(recv_list, group=self.gloo_group, group_src=0) + if isinstance(recv_list[0], int) and recv_list[0] == 0: + req.cpu_cache_task_status = InferReq._CpuCacheTaskStatus.FINISHED + return None + broadcast_data = recv_list[0] + item_size = broadcast_data["item_size"] + page_list = broadcast_data["page_list"] + ready_list = broadcast_data["ready_list"] + + page_indexes = torch.tensor(page_list, dtype=torch.int32, device="cpu", pin_memory=True) + page_readies = torch.tensor(ready_list, dtype=torch.bool, device="cpu", pin_memory=True) + move_token_num = item_size * self.args.cpu_cache_token_page_size + assert req.cur_kv_len >= item_size * self.args.cpu_cache_token_page_size + token_indexes = self.backend.model.req_manager.req_to_token_indexs[req.req_idx, 0:move_token_num] + + # TODO 更有效的分配策略。 + grid_num = 16 if self.need_sync_compute_stream or (not self.args.enable_fa3) else 1 + + # assert max(page_list) < self.cpu_cache_client.cpu_kv_cache_tensor.shape[0] + offload_gpu_kv_to_cpu( + token_indexes=token_indexes, + gpu_kv_cache=self.backend.model.mem_manager.kv_buffer, + cpu_kv_cache=self.cpu_cache_client.cpu_kv_cache_tensor, + page_indexes=page_indexes, + page_readies=page_readies, + tp_index=self.backend.rank_in_dp, + tp_world_size=self.backend.dp_world_size, + grid_num=grid_num, + ) + + sync_event = torch.cuda.Event() + sync_event.record() + req.cpu_cache_task_status = InferReq._CpuCacheTaskStatus.RUNNING + trans_task = TransTask( + page_indexes=page_indexes, page_readies=page_readies, req_obj=req, sync_event=sync_event + ) + + return trans_task + + def update_cpu_cache_task_states(self): + if self.backend.is_master_in_dp: + trans_ok_tasks = [] + while len(self.cpu_cache_handle_queue) != 0: + task: TransTask = self.cpu_cache_handle_queue.popleft() + if task.sync_event.query(): + trans_ok_tasks.append(task) + else: + self.cpu_cache_handle_queue.appendleft(task) + break + item_size = len(trans_ok_tasks) + dist.broadcast_object_list([item_size], group=self.filter_group, group_src=0) + else: + recv_list = [None] + dist.broadcast_object_list(recv_list, group=self.filter_group, group_src=0) + item_size = recv_list[0] + trans_ok_tasks: 
List[TransTask] = [self.cpu_cache_handle_queue.popleft() for _ in range(item_size)] + + if item_size > 0: + page_array_list = [task.page_indexes for task in trans_ok_tasks] + page_list = torch.cat(page_array_list, dim=0).tolist() + if self.backend.is_master_in_dp: + self.cpu_cache_client.lock.acquire_sleep1ms() + self.cpu_cache_client.update_pages_status_to_ready( + page_list=page_list, deref=True, disk_offload_enable=self.args.enable_disk_cache + ) + self.cpu_cache_client.lock.release() + for task in trans_ok_tasks: + task.req_obj.cpu_cache_task_status = InferReq._CpuCacheTaskStatus.FINISHED + return + + +@dataclasses.dataclass +class TransTask: + page_indexes: torch.Tensor + page_readies: torch.Tensor + req_obj: InferReq + sync_event: torch.cuda.Event diff --git a/lightllm/server/router/req_queue/base_queue.py b/lightllm/server/router/req_queue/base_queue.py index cf5435601..36aefae6e 100644 --- a/lightllm/server/router/req_queue/base_queue.py +++ b/lightllm/server/router/req_queue/base_queue.py @@ -4,10 +4,11 @@ from lightllm.server.core.objs import FinishStatus from lightllm.common.basemodel.infer_lock import g_router_lock from lightllm.utils.config_utils import get_fixed_kv_len +from lightllm.server.core.objs import StartArgs class BaseQueue: - def __init__(self, args, router, dp_index, dp_size_in_node) -> None: + def __init__(self, args: StartArgs, router, dp_index, dp_size_in_node) -> None: self.args = args self.dp_index = dp_index self.dp_size_in_node = dp_size_in_node @@ -26,6 +27,13 @@ def __init__(self, args, router, dp_index, dp_size_in_node) -> None: self.router_token_ratio = args.router_token_ratio # ratio to determine whether the router is busy self.router_max_new_token_len = args.router_max_new_token_len + def free_aborted_req_cpu_cache_pages(self, req: Req): + if self.args.enable_cpu_cache: + self.router.cpu_cache_client.lock.acquire_sleep1ms() + self.router.cpu_cache_client.deref_pages(req.cpu_cache_match_page_indexes.get_all()) + req.cpu_cache_match_page_indexes.clear() + self.router.cpu_cache_client.lock.release() + def extend(self, req_group: List[Req]): for req in req_group: req.sample_params.suggested_dp_index = self.dp_index diff --git a/lightllm/server/router/req_queue/chunked_prefill/beam_impl.py b/lightllm/server/router/req_queue/chunked_prefill/beam_impl.py index 143f6081c..ae7c90b33 100644 --- a/lightllm/server/router/req_queue/chunked_prefill/beam_impl.py +++ b/lightllm/server/router/req_queue/chunked_prefill/beam_impl.py @@ -43,9 +43,7 @@ def _can_add_new_group_reqs(self, cur_handle_group_reqs: List[Req], is_busy, new cumsum_len += cur_input_len - req.input_len # 减去共享的部分 need_max_token_num = max(need_max_token_num, cumsum_len + index * cur_ouput_len) - # prefill token 计算 - for req in cur_handle_group_reqs: - new_batch_first_router_need_tokens += req.shm_cur_output_len + # prefill token 计算, 因为对beam的prefill计算过程是共享的,所以只计算一个请求对应的token数量 new_batch_first_router_need_tokens += req.get_first_router_need_tokens() ok_token_num = ( need_max_token_num + self.router.shared_token_load.get_frozened_token_count(self.dp_index) @@ -121,6 +119,7 @@ def generate_new_batch(self, current_batch: Batch): new_batch = Batch(uuid.uuid4().int, can_run_list, dp_size_in_node=self.dp_size_in_node) for req in abort_req_list: + self.free_aborted_req_cpu_cache_pages(req) self.router.shm_req_manager.put_back_req_obj(req) self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :] return new_batch diff --git a/lightllm/server/router/req_queue/chunked_prefill/impl.py 
b/lightllm/server/router/req_queue/chunked_prefill/impl.py index 3730300a5..0d870b55d 100644 --- a/lightllm/server/router/req_queue/chunked_prefill/impl.py +++ b/lightllm/server/router/req_queue/chunked_prefill/impl.py @@ -96,6 +96,7 @@ def generate_new_batch(self, current_batch: Batch): if len(can_run_list) != 0: new_batch = Batch(uuid.uuid4().int, can_run_list, dp_size_in_node=self.dp_size_in_node) for req in abort_req_list: + self.free_aborted_req_cpu_cache_pages(req) self.router.shm_req_manager.put_back_req_obj(req) self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :] return new_batch diff --git a/lightllm/server/router/req_queue/chunked_prefill/impl_for_nixl_pd.py b/lightllm/server/router/req_queue/chunked_prefill/impl_for_nixl_pd.py index d8507cd6c..f2658159b 100644 --- a/lightllm/server/router/req_queue/chunked_prefill/impl_for_nixl_pd.py +++ b/lightllm/server/router/req_queue/chunked_prefill/impl_for_nixl_pd.py @@ -87,6 +87,7 @@ def generate_new_batch(self, current_batch: Batch): if len(can_run_list) != 0: new_batch = Batch(uuid.uuid4().int, can_run_list, dp_size_in_node=self.dp_size_in_node) for req in abort_req_list: + self.free_aborted_req_cpu_cache_pages(req) self.router.shm_req_manager.put_back_req_obj(req) self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :] return new_batch diff --git a/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd_decode.py b/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd_decode.py index e89dda66e..e0da13487 100644 --- a/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd_decode.py +++ b/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd_decode.py @@ -52,6 +52,7 @@ def generate_new_batch(self, current_batch: Batch): if len(can_run_list) != 0: new_batch = Batch(uuid.uuid4().int, can_run_list, dp_size_in_node=self.dp_size_in_node) for req in abort_req_list: + self.free_aborted_req_cpu_cache_pages(req) self.router.shm_req_manager.put_back_req_obj(req) self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :] return new_batch diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index 652c59d5e..1107770b3 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -8,7 +8,7 @@ import setproctitle from typing import List from lightllm.server.core.objs.io_objs.group_req import GroupReqIndexes -from lightllm.server.core.objs import ShmReqManager +from lightllm.server.core.objs import ShmReqManager, StartArgs asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) from lightllm.server.multimodal_params import MultimodalParams, ImageItem @@ -26,20 +26,26 @@ class VisualManager: def __init__( self, - args, - next_module_port, - visual_port, - cache_port, + args: StartArgs, visual_model_rpc_ports, ): context = zmq.Context(2) - self.send_to_next_module = context.socket(zmq.PUSH) # router or audio server (if --enable_multimodal_audio) - self.send_to_next_module.connect(f"{args.zmq_mode}127.0.0.1:{next_module_port}") - self.recv_from_httpserver = context.socket(zmq.PULL) - self.recv_from_httpserver.bind(f"{args.zmq_mode}127.0.0.1:{visual_port}") - self.cache_client = rpyc.connect("localhost", cache_port, config={"allow_pickle": True}) - self.cache_port = cache_port + if args.enable_multimodal_audio: + self.send_to_next_module = context.socket(zmq.PUSH) + self.send_to_next_module.connect(f"{args.zmq_mode}127.0.0.1:{args.audio_port}") + else: + if 
args.enable_cpu_cache: + self.send_to_next_module = context.socket(zmq.PUSH) + self.send_to_next_module.connect(f"{args.zmq_mode}127.0.0.1:{args.multi_level_kv_cache_port}") + else: + self.send_to_next_module = context.socket(zmq.PUSH) + self.send_to_next_module.connect(f"{args.zmq_mode}127.0.0.1:{args.router_port}") + + self.zmq_recv_socket = context.socket(zmq.PULL) + self.zmq_recv_socket.bind(f"{args.zmq_mode}127.0.0.1:{args.visual_port}") + self.cache_client = rpyc.connect("localhost", args.cache_port, config={"allow_pickle": True}) + self.cache_port = args.cache_port self.waiting_reqs: List[GroupReqIndexes] = [] self.model_weightdir = args.model_dir self.tp_world_size = args.tp @@ -163,7 +169,7 @@ async def loop_for_netio_req(self): while True: try: for _ in range(self.visual_recv_max_count): - recv_req: GroupReqIndexes = self.recv_from_httpserver.recv_pyobj(zmq.NOBLOCK) + recv_req: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) if isinstance(recv_req, GroupReqIndexes): self.waiting_reqs.append(recv_req) else: @@ -182,13 +188,13 @@ def clean_up(self): return -def start_visual_process(args, next_module_port, visual_port, cache_port, model_rpc_ports, pipe_writer): +def start_visual_process(args, model_rpc_ports, pipe_writer): # 注册graceful 退出的处理 graceful_registry(inspect.currentframe().f_code.co_name) setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::visual_server") start_parent_check_thread() try: - visualserver = VisualManager(args, next_module_port, visual_port, cache_port, model_rpc_ports) + visualserver = VisualManager(args=args, visual_model_rpc_ports=model_rpc_ports) asyncio.run(visualserver.wait_to_model_ready()) except Exception as e: logger.exception(str(e)) diff --git a/lightllm/utils/auto_shm_cleanup.py b/lightllm/utils/auto_shm_cleanup.py new file mode 100644 index 000000000..2417fef08 --- /dev/null +++ b/lightllm/utils/auto_shm_cleanup.py @@ -0,0 +1,135 @@ +import os +import ctypes +import atexit +import signal +import threading +import psutil +from multiprocessing import shared_memory +from typing import Set, Optional +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class AutoShmCleanup: + """ + 自动清理 System V 和 POSIX 共享内存 + shared_memory.SharedMemory虽然有自动请理功能,但如果自动清理时仍有进程占用会清理失败,这里可做最后兜底清理 + """ + + def __init__(self): + self.libc = None + self._init_libc() + # System V + self.registered_shm_keys = [] + self.registered_shm_ids = [] + # POSIX + self.registered_posix_shm_names = [] + self.signal_handlers_registered = False + self._register_handlers_for_cleanup() + + def _init_libc(self): + try: + self.libc = ctypes.CDLL("/usr/lib/x86_64-linux-gnu/libc.so.6") + self.libc.shmget.argtypes = (ctypes.c_long, ctypes.c_size_t, ctypes.c_int) + self.libc.shmget.restype = ctypes.c_int + self.libc.shmctl.argtypes = (ctypes.c_int, ctypes.c_int, ctypes.c_void_p) + self.libc.shmctl.restype = ctypes.c_int + except Exception as e: + logger.debug(f"libc init failed: {e}") + self.libc = None + + def _register_handlers_for_cleanup(self): + atexit.register(self._cleanup) + self.register_signal_handlers() + + def register_signal_handlers(self): + if self.signal_handlers_registered or not threading.current_thread() is threading.main_thread(): + return + for sig in (signal.SIGTERM, signal.SIGINT, signal.SIGHUP): + signal.signal(sig, self._signal_cleanup_handler) + self.signal_handlers_registered = True + + def _signal_cleanup_handler(self, signum, frame): + self._cleanup() + parent = psutil.Process(os.getpid()) + # 
递归拿到所有子进程并终止 + for ch in parent.children(recursive=True): + ch.kill() + + def _cleanup(self): + """清理:System V 执行 IPC_RMID,POSIX 执行 unlink。""" + removed_sysv = 0 + IPC_RMID = 0 + for shmid in self.registered_shm_ids: + try: + if self.libc.shmctl(shmid, IPC_RMID, None) == 0: + removed_sysv += 1 + except Exception as e: + logger.warning(f"cleanup: shmid {shmid} clean failed, reason: {e}") + pass + for key in self.registered_shm_keys: + shmid = self.libc.shmget(key, 0, 0) + try: + if shmid >= 0 and self.libc.shmctl(shmid, IPC_RMID, None) == 0: + removed_sysv += 1 + except Exception as e: + logger.warning(f"cleanup: shmid {shmid} clean failed, reason: {e}") + pass + if removed_sysv: + logger.info(f"cleanup: removed {removed_sysv} System V shm segments") + + removed_posix = 0 + for name in self.registered_posix_shm_names: + try: + shm = shared_memory.SharedMemory(name=name, create=False) + try: + shm.unlink() + removed_posix += 1 + except FileNotFoundError: + pass + except Exception as e: + logger.warning(f"cleanup: posix shm {name} clean failed, reason: {e}") + pass + finally: + shm.close() + except FileNotFoundError: + pass + except Exception as e: + logger.warning(f"cleanup: posix {name} clean failed, reason: {e}") + pass + if removed_posix: + logger.info(f"cleanup: unlinked {removed_posix} POSIX shm segments") + + def register_sysv_shm(self, key: int, shmid: Optional[int] = None): + """注册 System V 共享内存。""" + self.registered_shm_keys.append(key) + if shmid is not None: + self.registered_shm_ids.append(shmid) + return + + def register_posix_shm(self, name: str): + """注册 POSIX 共享内存。""" + self.registered_posix_shm_names.append(name) + return + + +# 全局自动清理器实例 +_auto_cleanup = None + + +def get_auto_cleanup() -> AutoShmCleanup: + """获取全局自动清理器实例""" + global _auto_cleanup + if _auto_cleanup is None: + _auto_cleanup = AutoShmCleanup() + _auto_cleanup.register_signal_handlers() + return _auto_cleanup + + +def register_sysv_shm_for_cleanup(key: int, shmid: Optional[int] = None): + get_auto_cleanup().register_sysv_shm(key, shmid) + + +def register_posix_shm_for_cleanup(name: str): + get_auto_cleanup().register_posix_shm(name) diff --git a/lightllm/utils/config_utils.py b/lightllm/utils/config_utils.py index afaf71a25..a4a51a5da 100644 --- a/lightllm/utils/config_utils.py +++ b/lightllm/utils/config_utils.py @@ -1,6 +1,6 @@ import json import os -from typing import Optional +from typing import Optional, List from functools import lru_cache from .envs_utils import get_env_start_args from lightllm.utils.log_utils import init_logger @@ -14,52 +14,83 @@ def get_config_json(model_path: str): return json_obj -def get_hidden_size(model_path: str) -> Optional[int]: - # try to get hidden_size in config.json +def _get_config_llm_keyvalue(model_path: str, key_name: str): config_json = get_config_json(model_path) try: - hidden_size = config_json["hidden_size"] + value = config_json[key_name] except: # for some multimodal model try: - hidden_size = config_json["llm_config"]["hidden_size"] + value = config_json["llm_config"][key_name] except: - hidden_size = config_json.get("text_config", {}).get("hidden_size") + value = config_json.get("text_config", {}).get(key_name) + + if value is None: + logger.error(f"cannot get {key_name} from config.json, return None") + + return value + +def get_hidden_size(model_path: str) -> Optional[int]: + hidden_size = _get_config_llm_keyvalue(model_path=model_path, key_name="hidden_size") if isinstance(hidden_size, int): return hidden_size - - logger.error("cannot get hidden size from 
config.json, return None instead") return None @lru_cache(maxsize=None) def get_num_key_value_heads(model_path: str) -> int: - config_json = get_config_json(model_path) - try: - num_key_value_heads = config_json["num_key_value_heads"] - except: - # for some multimodal model - num_key_value_heads = config_json["llm_config"]["num_key_value_heads"] - return num_key_value_heads + num_key_value_heads = _get_config_llm_keyvalue(model_path=model_path, key_name="num_key_value_heads") + if isinstance(num_key_value_heads, int): + return num_key_value_heads + return None -def get_eos_token_ids(model_path: str): - config_json = get_config_json(model_path) - try: - eos_token_id = config_json["eos_token_id"] - except: - # for some multimode model. - try: - eos_token_id = config_json["llm_config"]["eos_token_id"] - except: - eos_token_id = config_json["text_config"]["eos_token_id"] +@lru_cache(maxsize=None) +def get_num_attention_heads(model_path: str) -> int: + num_attention_heads = _get_config_llm_keyvalue(model_path=model_path, key_name="num_attention_heads") + if isinstance(num_attention_heads, int): + return num_attention_heads + return None + +@lru_cache(maxsize=None) +def get_head_dim(model_path: str) -> int: + head_dim = _get_config_llm_keyvalue(model_path=model_path, key_name="head_dim") + if isinstance(head_dim, int): + return head_dim + + # calcu head_dim + head_dim = get_hidden_size(model_path=model_path) // get_num_attention_heads(model_path=model_path) + + return head_dim + + +@lru_cache(maxsize=None) +def get_layer_num(model_path: str) -> int: + num_hidden_layers = _get_config_llm_keyvalue(model_path=model_path, key_name="num_hidden_layers") + if isinstance(num_hidden_layers, int): + return num_hidden_layers + return None + + +@lru_cache(maxsize=None) +def get_model_type(model_path: str) -> str: + model_type = _get_config_llm_keyvalue(model_path=model_path, key_name="model_type") + if isinstance(model_type, str): + return model_type + return None + + +def get_eos_token_ids(model_path: str) -> Optional[List[int]]: + eos_token_id = _get_config_llm_keyvalue(model_path=model_path, key_name="eos_token_id") if isinstance(eos_token_id, int): return [eos_token_id] if isinstance(eos_token_id, list): return eos_token_id + assert False, "error eos_token_id format in config.json" + return def get_model_architectures(model_path: str): @@ -88,13 +119,12 @@ def get_vocab_size(model_path: str): def get_dtype(model_path: str): - config_json = get_config_json(model_path) - try: - torch_dtype = config_json["torch_dtype"] - return torch_dtype - except: + torch_dtype = _get_config_llm_keyvalue(model_path=model_path, key_name="torch_dtype") + if torch_dtype is None: logger.warning("torch_dtype not in config.json, use float16 as default") return "float16" + else: + return torch_dtype @lru_cache(maxsize=None) diff --git a/lightllm/utils/envs_utils.py b/lightllm/utils/envs_utils.py index 3878c8f8b..2f795aa23 100644 --- a/lightllm/utils/envs_utils.py +++ b/lightllm/utils/envs_utils.py @@ -173,3 +173,11 @@ def use_whisper_sdpa_attention() -> bool: whisper重训后,使用特定的实现可以提升精度,用该函数控制使用的att实现。 """ return enable_env_vars("LIGHTLLM_USE_WHISPER_SDPA_ATTENTION") + + +@lru_cache(maxsize=None) +def disable_cpu_kvcache_sync() -> bool: + """ + 实验用环境遍历,未来可能会移除 + """ + return enable_env_vars("LIGHTLLM_DISABLE_CPU_CACHE_SYNC") diff --git a/lightllm/utils/kv_cache_utils.py b/lightllm/utils/kv_cache_utils.py new file mode 100644 index 000000000..3670e804f --- /dev/null +++ b/lightllm/utils/kv_cache_utils.py @@ -0,0 +1,274 @@ +import 
torch +import ctypes +import dataclasses +import os +import xxhash +import threading +import time +import numpy as np +import triton +from functools import lru_cache +from lightllm.utils.envs_utils import get_env_start_args +from lightllm.utils.log_utils import init_logger +from lightllm.utils.config_utils import get_num_key_value_heads, get_head_dim, get_layer_num, get_model_type +from typing import List, Tuple, Optional +from tqdm import tqdm +from lightllm.utils.auto_shm_cleanup import register_sysv_shm_for_cleanup +from lightllm.utils.dist_utils import get_current_device_id + +logger = init_logger(__name__) + + +def compute_token_list_hash(tokens: List[int], cpu_cache_token_page_size: int) -> List[int]: + if len(tokens) == 0: + return [] + + chunks_hash_value = [] + hsum = xxhash.xxh3_64() + + # 计算每个分块的哈希值, 但是输入token需要少一个,因为 + # 如果计算所有的token,会导致输入input_len 命中全长的 + # cpu cache, 导致prefill 过程无法有输入来导出下一个输出。 + calcu_num = (len(tokens) - 1) // cpu_cache_token_page_size + + for i in range(calcu_num): + start_index = i * cpu_cache_token_page_size + end_index = (i + 1) * cpu_cache_token_page_size + chunk = tokens[start_index:end_index] + chunk_np = np.array(chunk, dtype=np.uint64) + hsum.update(chunk_np.tobytes()) + + hash_value = hsum.intdigest() + chunks_hash_value.append(hash_value) + + return chunks_hash_value + + +@lru_cache(maxsize=None) +def calcu_cpu_cache_meta() -> "CpuKVCacheMeta": + args = get_env_start_args() + assert args.enable_cpu_cache + + if get_model_type(model_path=args.model_dir) in ["deepseek_v2", "deepseek_v3"]: + item_size = 2 + num_key_value_heads = 1 + head_dim = 512 + 64 + layer_num = get_layer_num(args.model_dir) + else: + item_size = 2 + num_key_value_heads = get_num_key_value_heads(args.model_dir) * 2 + head_dim = get_head_dim(args.model_dir) + layer_num = get_layer_num(args.model_dir) + + if args.mtp_mode is not None: + # TODO 可能会存在不同mtp模式的精度问题 + layer_num += 1 + + one_token_byte_size = layer_num * num_key_value_heads * head_dim * item_size + one_page_byte_size = args.cpu_cache_token_page_size * one_token_byte_size + cpu_cache_page_num = int((args.cpu_cache_storage_size * 1024 * 1024 * 1024) / one_page_byte_size) + + cpu_cache_meta = CpuKVCacheMeta( + page_num=cpu_cache_page_num, + layer_num=layer_num, + token_page_size=args.cpu_cache_token_page_size, + num_heads=num_key_value_heads, + head_dim=head_dim, + item_size=item_size, + ) + + logger.info(f"cpu kv cache page num: {cpu_cache_meta.page_num}") + + return cpu_cache_meta + + +@lru_cache(maxsize=None) +def create_shm_kv_cache_ptr() -> int: + libc = ctypes.CDLL("/usr/lib/x86_64-linux-gnu/libc.so.6", use_errno=True) + libc.shmget.argtypes = (ctypes.c_long, ctypes.c_size_t, ctypes.c_int) + libc.shmget.restype = ctypes.c_int + libc.shmat.argtypes = (ctypes.c_int, ctypes.c_void_p, ctypes.c_int) + libc.shmat.restype = ctypes.c_void_p + + args = get_env_start_args() + key = args.cpu_kv_cache_shm_id + requested_size = calcu_cpu_cache_meta().calcu_size() + use_hugetlb = True + + # 计算大页大小(默认从 /proc/meminfo 读取 Hugepagesize) + def _get_default_hugepage_size() -> int: + try: + with open("/proc/meminfo", "r") as f: + for line in f: + if line.startswith("Hugepagesize:"): + parts = line.split() + if len(parts) >= 2: + kb = int(parts[1]) + return kb * 1024 + except Exception: + pass + return 2 * 1024 * 1024 # fallback 2MB + + # 向上对齐到大页大小 + huge_sz = _get_default_hugepage_size() + size_to_alloc = triton.cdiv(requested_size, huge_sz) * huge_sz + shmflg = 0o666 | 0o1000 # 权限和 IPC_CREAT 标志 + if use_hugetlb: + SHM_HUGETLB = 0o4000 + 
shmflg |= SHM_HUGETLB + logger.info( + f"Using SHM_HUGETLB, hugepage_size={huge_sz} bytes, requested={requested_size}, alloc={size_to_alloc}" + ) + + # 优先尝试 HugeTLB 分配,失败则回退到普通页 + shmid = libc.shmget(key, size_to_alloc, shmflg) + hugepages_num = (size_to_alloc + 1024 * 1024 * 1024 - 1) // (1024 * 1024 * 1024) + if shmid < 0 and use_hugetlb: + err = ctypes.get_errno() + logger.error( + f"shmget with SHM_HUGETLB failed (errno={err}). Falling back to regular pages." + f"You may need to configure hugepages manually, e.g.," + f"sudo sed -i 's/^GRUB_CMDLINE_LINUX=\"/& default_hugepagesz=1G \ + hugepagesz=1G hugepages={hugepages_num}/' /etc/default/grub" + f"sudo update-grub" + f"sudo reboot" + ) + # 回退:去掉 HUGETLB 标志,使用请求原始大小 + shmflg_n = 0o666 | 0o1000 + shmid = libc.shmget(key, size_to_alloc, shmflg_n) + + if shmid < 0: + err = ctypes.get_errno() + raise Exception(f"Error creating shared memory (errno={err})") + + register_sysv_shm_for_cleanup(key, shmid) + logger.info(f"Shared memory ID: {shmid}") + + # 附加共享内存 + shm_addr = libc.shmat(shmid, ctypes.c_void_p(0), 0) + if shm_addr == ctypes.c_void_p(-1).value: + raise Exception("Error attaching shared memory") + logger.info(f"Shared cpu kv cache tensor memory at address: {shm_addr}") + + return shm_addr + + +@dataclasses.dataclass +class CpuKVCacheMeta: + page_num: int + layer_num: int + token_page_size: int + num_heads: int + head_dim: int + item_size: int + + def calcu_size(self): + return self.page_num * self.layer_num * self.token_page_size * self.num_heads * self.head_dim * self.item_size + + +@lru_cache(maxsize=None) +def register_shm_ptr_to_pin(shm_ptr: int, size: int) -> "AsyncRegistrationHandle": + """Start async cudaHostRegister on the given [shm_ptr, shm_ptr+size) and return a handle.""" + chunk_bytes = 128 * 1024 * 1024 # 128M性能最好 + tasks: list[tuple[int, int]] = [] + offset = 0 + while offset < size: + seg_len = min(chunk_bytes, size - offset) + tasks.append((offset, seg_len)) + offset += seg_len + + handle = AsyncRegistrationHandle(total_tasks=len(tasks)) + + def _worker(): + cuda = ctypes.CDLL("/usr/local/cuda/targets/x86_64-linux/lib/libcudart.so") + cuda.cudaHostRegister.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.c_uint] + cuda.cudaHostRegister.restype = ctypes.c_int + cuda.cudaHostGetDevicePointer.argtypes = [ctypes.POINTER(ctypes.c_void_p), ctypes.c_void_p, ctypes.c_int] + cuda.cudaHostGetDevicePointer.restype = ctypes.c_int + + cudaHostRegisterFlag = 3 + + torch.cuda.set_device(get_current_device_id()) + # TODO 这个地方的分块注册是否具备合法性和合理性。 + for offset, seg_len in tasks: + ptr = ctypes.c_void_p(shm_ptr + offset) + r = cuda.cudaHostRegister(ptr, ctypes.c_size_t(seg_len), cudaHostRegisterFlag) + if r != 0: + raise Exception(f"cudaHostRegister failed with error code {r}, prefer to use hugetlb") + handle.task_count += 1 + + device_ptr = ctypes.c_void_p() + host_ptr = ctypes.c_void_p(shm_ptr) + res = cuda.cudaHostGetDevicePointer(ctypes.byref(device_ptr), host_ptr, 0) + if res != 0: + raise Exception(f"cudaHostGetDevicePointer failed with error code {res}") + assert host_ptr.value == device_ptr.value + handle.tasks_finished.set() + + th = threading.Thread(target=_worker, name="cpu_cache_register", daemon=True) + handle.thread = th + th.start() + return handle + + +class AsyncRegistrationHandle: + """A handle for async host memory registration. + + - wait(): blocks until registration finishes, prints tqdm progress, and returns device pointer (int). 
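+    - task_count is bumped by the background registration thread after each registered chunk,
+      so wait() can report tqdm progress while the registration is still in flight.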
+ """ + + def __init__(self, total_tasks: int): + self.total_tasks = total_tasks + self.task_count = 0 + self.thread: Optional[threading.Thread] = None + self.tasks_finished = threading.Event() + + def wait(self): + """Block until the async registration completes. Only here we print tqdm progress.""" + last_count = 0 + desc = f"pid {os.getpid()} Registering pinned host memory (async)" + with tqdm(total=self.total_tasks, desc=desc) as pbar: + while not self.tasks_finished.is_set(): + cur = self.task_count + if cur > last_count: + pbar.update(cur - last_count) + last_count = cur + time.sleep(0.01) + # final update + cur = self.task_count + if cur > last_count: + pbar.update(cur - last_count) + last_count = cur + + if self.thread is not None and self.thread.is_alive(): + self.thread.join() + + return + + +@lru_cache(maxsize=None) +def attach_shm_kv_cache_ptr() -> int: + libc = ctypes.CDLL("/usr/lib/x86_64-linux-gnu/libc.so.6", use_errno=True) + libc.shmget.argtypes = (ctypes.c_long, ctypes.c_size_t, ctypes.c_int) + libc.shmget.restype = ctypes.c_int + libc.shmat.argtypes = (ctypes.c_int, ctypes.c_void_p, ctypes.c_int) + libc.shmat.restype = ctypes.c_void_p + + # Try to locate an existing SHM without creating a new one + args = get_env_start_args() + key = args.cpu_kv_cache_shm_id + shmid = libc.shmget(key, 0, 0) + if shmid < 0: + size = calcu_cpu_cache_meta().calcu_size() + shmid = libc.shmget(key, size, 0) + if shmid < 0: + err = ctypes.get_errno() + raise Exception(f"Error locating existing shared memory (errno={err})") + + shm_addr = libc.shmat(shmid, ctypes.c_void_p(0), 0) + if shm_addr == ctypes.c_void_p(-1).value: + err = ctypes.get_errno() + raise Exception(f"Error attaching shared memory (errno={err})") + + logger.info(f"Attached to SHM key={key}, shmid={shmid}, addr={shm_addr}") + return shm_addr diff --git a/lightllm/utils/shm_utils.py b/lightllm/utils/shm_utils.py index e5c98e5ab..7ceea6f8b 100644 --- a/lightllm/utils/shm_utils.py +++ b/lightllm/utils/shm_utils.py @@ -1,11 +1,12 @@ from multiprocessing import shared_memory from filelock import FileLock from lightllm.utils.log_utils import init_logger +from lightllm.utils.auto_shm_cleanup import register_posix_shm_for_cleanup logger = init_logger(__name__) -def create_or_link_shm(name, expected_size, force_mode=None): +def create_or_link_shm(name, expected_size, force_mode=None, auto_cleanup=False): """ Args: name: name of the shared memory @@ -26,15 +27,15 @@ def create_or_link_shm(name, expected_size, force_mode=None): if force_mode == "create": with FileLock(lock_name): - return _force_create_shm(name, expected_size) + return _force_create_shm(name, expected_size, auto_cleanup) elif force_mode == "link": return _force_link_shm(name, expected_size) else: with FileLock(lock_name): - return _smart_create_or_link_shm(name, expected_size) + return _smart_create_or_link_shm(name, expected_size, auto_cleanup) -def _force_create_shm(name, expected_size): +def _force_create_shm(name, expected_size, auto_cleanup): """强制创建新的共享内存""" try: existing_shm = shared_memory.SharedMemory(name=name) @@ -45,6 +46,8 @@ def _force_create_shm(name, expected_size): # 创建新的共享内存 shm = shared_memory.SharedMemory(name=name, create=True, size=expected_size) + if auto_cleanup: + register_posix_shm_for_cleanup(name) return shm @@ -62,7 +65,7 @@ def _force_link_shm(name, expected_size): raise e -def _smart_create_or_link_shm(name, expected_size): +def _smart_create_or_link_shm(name, expected_size, auto_cleanup): """优先连接,不存在则创建""" try: shm = 
_force_link_shm(name=name, expected_size=expected_size) @@ -70,4 +73,4 @@ def _smart_create_or_link_shm(name, expected_size): except: pass - return _force_create_shm(name=name, expected_size=expected_size) + return _force_create_shm(name=name, expected_size=expected_size, auto_cleanup=auto_cleanup) diff --git a/requirements.txt b/requirements.txt index 073284198..40d0b4956 100644 --- a/requirements.txt +++ b/requirements.txt @@ -87,3 +87,4 @@ librosa==0.11.0 cuda_bindings==12.9.0 orjson==3.11.2 setproctitle==1.3.6 +xxhash==3.6.0 \ No newline at end of file diff --git a/test/benchmark/service/benchmark_qps.py b/test/benchmark/service/benchmark_qps.py index 7175be963..796bbfb97 100644 --- a/test/benchmark/service/benchmark_qps.py +++ b/test/benchmark/service/benchmark_qps.py @@ -342,7 +342,7 @@ def main(): assert args.tokenizer_path is not None model_name.append(args.tokenizer_path) - # seed_all(args.seed) + seed_all(args.seed) url = args.url tokenizer = get_tokenizer(args.tokenizer_path) if args.data_path is not None: diff --git a/unit_tests/server/router/dynamic_prompt/test_radix_cache.py b/unit_tests/server/router/dynamic_prompt/test_radix_cache.py index 0462d724b..505cbbc1c 100644 --- a/unit_tests/server/router/dynamic_prompt/test_radix_cache.py +++ b/unit_tests/server/router/dynamic_prompt/test_radix_cache.py @@ -5,13 +5,13 @@ def test_case1(): tree = RadixCache("unique_name", 100, 0) - ans = tree.insert(torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int64, device="cpu")) + ans, _ = tree.insert(torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int64, device="cpu")) assert ans == 0 tree.print_self() - ans = tree.insert(torch.tensor([0, 1, 2, 3, 4, 7, 8, 9], dtype=torch.int64, device="cpu")) + ans, _ = tree.insert(torch.tensor([0, 1, 2, 3, 4, 7, 8, 9], dtype=torch.int64, device="cpu")) assert ans == 5 tree.print_self() - ans = tree.insert(torch.tensor([0, 1, 2, 3, 4, 7, 8, 9], dtype=torch.int64, device="cpu")) + ans, _ = tree.insert(torch.tensor([0, 1, 2, 3, 4, 7, 8, 9], dtype=torch.int64, device="cpu")) assert ans == 8 tree.print_self() @@ -26,8 +26,8 @@ def test_case1(): def test_case2(): tree = RadixCache("unique_name", 100, 1) - ans = tree.insert(torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int64, device="cpu")) - ans = tree.insert(torch.tensor([0, 1, 2, 3, 4, 7, 8, 9], dtype=torch.int64, device="cpu")) + ans, _ = tree.insert(torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int64, device="cpu")) + ans, _ = tree.insert(torch.tensor([0, 1, 2, 3, 4, 7, 8, 9], dtype=torch.int64, device="cpu")) tree.print_self() tree_node, size, values = tree.match_prefix( @@ -52,8 +52,8 @@ def test_case2(): def test_case3(): tree = RadixCache("unique_name", 100, 2) - ans = tree.insert(torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int64, device="cpu")) - ans = tree.insert(torch.tensor([0, 1, 2, 3, 4, 7, 8, 9], dtype=torch.int64, device="cpu")) + ans, _ = tree.insert(torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int64, device="cpu")) + ans, _ = tree.insert(torch.tensor([0, 1, 2, 3, 4, 7, 8, 9], dtype=torch.int64, device="cpu")) tree.print_self() tree_node, size, values = tree.match_prefix( @@ -82,8 +82,8 @@ def test_case3(): def test_case4(): tree = RadixCache("unique_name", 100, 2) - ans = tree.insert(torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int64, device="cpu")) - ans = tree.insert(torch.tensor([0, 1, 2, 3, 4, 7, 8, 9], dtype=torch.int64, device="cpu")) + ans, _ = tree.insert(torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], 
dtype=torch.int64, device="cpu")) + ans, _ = tree.insert(torch.tensor([0, 1, 2, 3, 4, 7, 8, 9], dtype=torch.int64, device="cpu")) tree.print_self() tree.clear_tree_nodes()
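Note on the page hashing added in lightllm/utils/kv_cache_utils.py: compute_token_list_hash only hashes complete pages of cpu_cache_token_page_size tokens, deliberately drops the last prompt token so prefill always has at least one token left to compute, and reuses a single rolling xxh3_64 state, so every page key is a prefix hash that depends on all earlier pages. A minimal standalone sketch of that behaviour (not part of the patch; it only assumes the xxhash and numpy packages already listed in requirements.txt):

    import numpy as np
    import xxhash

    def page_prefix_hashes(tokens, page_size):
        # Mirrors compute_token_list_hash: hash complete pages only, and exclude the
        # last token so a prompt can never be fully served from the CPU cache at prefill.
        hashes = []
        hsum = xxhash.xxh3_64()
        for i in range((len(tokens) - 1) // page_size):
            chunk = np.array(tokens[i * page_size:(i + 1) * page_size], dtype=np.uint64)
            hsum.update(chunk.tobytes())  # rolling update -> each page key covers the whole prefix
            hashes.append(hsum.intdigest())
        return hashes

    tokens = list(range(9))               # 9 prompt tokens, page_size = 4
    print(page_prefix_hashes(tokens, 4))  # 2 keys: pages [0..3] and [0..7]; token 8 is ignored

With 9 tokens and a page size of 4 this yields two keys, matching the `(len(tokens) - 1) // cpu_cache_token_page_size` page count in the source, and two prompts that share only a suffix never collide because each key encodes the full prefix.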