From c0bdab08e4e40ea2f8e729546659769811b406ca Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 9 Jun 2026 02:25:39 +0000 Subject: [PATCH 1/3] nixl pd support qwen3.5 --- .../deepseek2_mem_manager.py | 6 + .../kv_cache_mem_manager/mem_manager.py | 9 +- .../qwen3next_mem_manager.py | 356 +++++++++++++++++- .../linear_att_cache_manager/config_objs.py | 17 +- lightllm/models/qwen3next/model.py | 6 +- lightllm/server/pd_io_struct.py | 20 +- .../server/router/model_infer/infer_batch.py | 31 +- .../pd_nixl/decode_node_impl/decode_impl.py | 21 ++ .../decode_node_impl/decode_impl_for_dp.py | 2 + .../decode_node_impl/decode_trans_process.py | 4 +- .../pd_nixl/prefill_node_impl/prefill_impl.py | 38 +- .../prefill_node_impl/prefill_impl_for_dp.py | 10 +- .../prefill_trans_process.py | 2 + 13 files changed, 484 insertions(+), 38 deletions(-) diff --git a/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py b/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py index d49c8d7e73..7a24a59110 100644 --- a/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py @@ -56,7 +56,10 @@ def write_mem_to_page_kv_move_buffer( dp_index: int, mem_managers: List["MemoryManager"], dp_world_size: int, + page_kind: str = "kv", + req_idx: int = None, ): + assert page_kind == "kv", f"{type(self).__name__} does not support page_kind={page_kind}" cur_page = self.kv_move_buffer[page_index] pin_mem_indexes = self._buffer_mem_indexes_tensors[page_index][0 : len(mem_indexes)] pin_mem_indexes.numpy()[:] = mem_indexes @@ -77,7 +80,10 @@ def read_page_kv_move_buffer_to_mem( dp_index: int, mem_managers: List["MemoryManager"], dp_world_size: int, + page_kind: str = "kv", + req_idx: int = None, ): + assert page_kind == "kv", f"{type(self).__name__} does not support page_kind={page_kind}" cur_page = self.kv_move_buffer[page_index] pin_mem_indexes = self._buffer_mem_indexes_tensors[page_index][0 : len(mem_indexes)] 
pin_mem_indexes.numpy()[:] = mem_indexes diff --git a/lightllm/common/kv_cache_mem_manager/mem_manager.py b/lightllm/common/kv_cache_mem_manager/mem_manager.py index 0a1deba499..47364af5f9 100755 --- a/lightllm/common/kv_cache_mem_manager/mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/mem_manager.py @@ -102,9 +102,6 @@ def alloc_kv_move_buffer(self, max_req_total_len): return def alloc_paged_kv_move_buffer(self, page_num, page_size) -> torch.Tensor: - if isinstance(self, MemoryManager) and type(self) is not MemoryManager: - raise NotImplementedError("subclass need reimpl this method") - num_kv_head = get_num_key_value_heads(get_env_start_args().model_dir) self.kv_move_buffer = torch.empty( (page_num, page_size, self.layer_num, 2 * num_kv_head, self.head_dim), dtype=self.dtype, device="cuda" @@ -121,7 +118,10 @@ def write_mem_to_page_kv_move_buffer( dp_index: int, mem_managers: List["MemoryManager"], dp_world_size: int, + page_kind: str = "kv", + req_idx: int = None, ): + assert page_kind == "kv", f"{type(self).__name__} does not support page_kind={page_kind}" cur_page = self.kv_move_buffer[page_index] pin_mem_indexes = self._buffer_mem_indexes_tensors[page_index][0 : len(mem_indexes)] pin_mem_indexes.numpy()[:] = mem_indexes @@ -150,7 +150,10 @@ def read_page_kv_move_buffer_to_mem( dp_index: int, mem_managers: List["MemoryManager"], dp_world_size: int, + page_kind: str = "kv", + req_idx: int = None, ): + assert page_kind == "kv", f"{type(self).__name__} does not support page_kind={page_kind}" cur_page = self.kv_move_buffer[page_index] pin_mem_indexes = self._buffer_mem_indexes_tensors[page_index][0 : len(mem_indexes)] pin_mem_indexes.numpy()[:] = mem_indexes diff --git a/lightllm/common/kv_cache_mem_manager/qwen3next_mem_manager.py b/lightllm/common/kv_cache_mem_manager/qwen3next_mem_manager.py index caca4bb621..c7ce9d96ba 100644 --- a/lightllm/common/kv_cache_mem_manager/qwen3next_mem_manager.py +++ 
b/lightllm/common/kv_cache_mem_manager/qwen3next_mem_manager.py @@ -5,7 +5,7 @@ from lightllm.utils.envs_utils import get_env_start_args from lightllm.common.linear_att_cache_manager import LinearAttCacheConfig, LinearAttCacheManager from .operator import LinearAttMemOperator -from typing import Tuple, Any +from typing import Tuple, Any, List logger = init_logger(__name__) @@ -63,3 +63,357 @@ def _free_buffers(self): def _free_linear_att_buffers(self): self.linear_att_big_page_buffers = None return + + def write_to_shm(self, req_manager): + self.req_to_conv_state = req_manager.req_to_conv_state + self.req_to_ssm_state = req_manager.req_to_ssm_state + return super().write_to_shm(req_manager) + + def alloc_paged_kv_move_buffer(self, page_num, page_size) -> torch.Tensor: + kv_move_buffer = super().alloc_paged_kv_move_buffer(page_num, page_size) + Qwen3NextLinearAttPageHelper(self).assert_page_size() + return kv_move_buffer + + def write_mem_to_page_kv_move_buffer( + self, + mem_indexes, + page_index: int, + dp_index: int, + mem_managers, + dp_world_size: int, + page_kind: str = "kv", + req_idx: int = None, + ): + if page_kind == "kv": + return super().write_mem_to_page_kv_move_buffer( + mem_indexes=mem_indexes, + page_index=page_index, + dp_index=dp_index, + mem_managers=mem_managers, + dp_world_size=dp_world_size, + page_kind=page_kind, + req_idx=req_idx, + ) + assert page_kind == "linear_att_state", f"unknown page_kind={page_kind}" + assert req_idx is not None + helper = Qwen3NextLinearAttPageHelper(self) + dp_mems = helper.get_dp_mems(mem_managers, dp_index, dp_world_size) + helper.write_req_to_page(page_index=page_index, req_idx=req_idx, dp_mems=dp_mems) + return + + def read_page_kv_move_buffer_to_mem( + self, + mem_indexes, + page_index: int, + dp_index: int, + mem_managers, + dp_world_size: int, + page_kind: str = "kv", + req_idx: int = None, + ): + if page_kind == "kv": + return super().read_page_kv_move_buffer_to_mem( + mem_indexes=mem_indexes, + 
page_index=page_index, + dp_index=dp_index, + mem_managers=mem_managers, + dp_world_size=dp_world_size, + page_kind=page_kind, + req_idx=req_idx, + ) + assert page_kind == "linear_att_state", f"unknown page_kind={page_kind}" + assert req_idx is not None + helper = Qwen3NextLinearAttPageHelper(self) + dp_mems = helper.get_dp_mems(mem_managers, dp_index, dp_world_size) + helper.read_page_to_req(page_index=page_index, req_idx=req_idx, dp_mems=dp_mems) + return + + +class Qwen3NextLinearAttPageHelper: + def __init__(self, mem_manager: "Qwen3NextMemManager"): + self.mem_manager = mem_manager + self.linear_config = mem_manager.linear_config + self.req_to_conv_state = mem_manager.req_to_conv_state + self.req_to_ssm_state = mem_manager.req_to_ssm_state + self.global_linear_k_heads = self.linear_config.global_linear_k_heads + self.global_linear_v_heads = self.linear_config.global_linear_v_heads + + self.global_q_dim = self.global_linear_k_heads * self.linear_config.head_linear_k_dim + self.global_k_dim = self.global_q_dim + self.global_v_heads = self.global_linear_v_heads + self.global_v_dim = self.global_v_heads * self.linear_config.head_linear_v_dim + # conv state follows mixed_qkv layout: [q, k, v], each as a flat channel block. 
+ self.conv_shape = ( + self.linear_config.linear_layer_num, + self.global_q_dim + self.global_k_dim + self.global_v_dim, + self.linear_config.conv_kernel_size - 1, + ) + self.ssm_shape = ( + self.linear_config.linear_layer_num, + self.global_v_heads, + self.linear_config.head_linear_k_dim, + self.linear_config.head_linear_v_dim, + ) + + self.conv_nbytes = ( + self.conv_shape[0] * self.conv_shape[1] * self.conv_shape[2] * self.req_to_conv_state.buffer.element_size() + ) + ssm_alignment = self.req_to_ssm_state.buffer.element_size() + # 做一下字节对齐,防止切出来的不对齐,导致一些操作的性能下降。 + self.ssm_offset = ((self.conv_nbytes + ssm_alignment - 1) // ssm_alignment) * ssm_alignment + self.ssm_nbytes = ( + self.ssm_shape[0] + * self.ssm_shape[1] + * self.ssm_shape[2] + * self.ssm_shape[3] + * self.req_to_ssm_state.buffer.element_size() + ) + self.state_nbytes = self.ssm_offset + self.ssm_nbytes + + def assert_page_size(self): + kv_move_buffer = self.mem_manager.kv_move_buffer + page_nbytes = kv_move_buffer[0].numel() * kv_move_buffer.element_size() + assert ( + page_nbytes >= self.state_nbytes + ), f"nixl kv move page bytes {page_nbytes} is smaller than global linear att state bytes {self.state_nbytes}" + return + + def get_dp_mems(self, mem_managers: List["Qwen3NextMemManager"], dp_index: int, dp_world_size: int): + dp_mems = mem_managers[(dp_index * dp_world_size) : ((dp_index + 1) * dp_world_size)] + assert len(dp_mems) == dp_world_size + for mem in dp_mems: + assert hasattr(mem, "req_to_conv_state") and hasattr(mem, "req_to_ssm_state") + assert mem.linear_config.linear_layer_num == self.linear_config.linear_layer_num + assert mem.linear_config.conv_kernel_size == self.linear_config.conv_kernel_size + assert mem.linear_config.head_linear_k_dim == self.linear_config.head_linear_k_dim + assert mem.linear_config.head_linear_v_dim == self.linear_config.head_linear_v_dim + assert mem.linear_config.num_linear_k_heads == self.linear_config.num_linear_k_heads + assert 
mem.linear_config.num_linear_v_heads == self.linear_config.num_linear_v_heads + return dp_mems + + def view_page_to_linear_att_state(self, page_index: int): + page_bytes = self.mem_manager.kv_move_buffer[page_index].view(torch.uint8).reshape(-1) + conv_page = page_bytes[0 : self.conv_nbytes].view(self.req_to_conv_state.buffer.dtype).view(self.conv_shape) + ssm_page = ( + page_bytes[self.ssm_offset : self.ssm_offset + self.ssm_nbytes] + .view(self.req_to_ssm_state.buffer.dtype) + .view(self.ssm_shape) + ) + return conv_page, ssm_page + + def write_req_to_page( + self, + page_index: int, + req_idx: int, + dp_mems: List["Qwen3NextMemManager"], + ): + conv_page, ssm_page = self.view_page_to_linear_att_state(page_index) + req_buffer_idx = req_idx * (get_env_start_args().mtp_step + 1) + for tp_index, mem in enumerate(dp_mems): + self._write_one_rank(mem, tp_index, req_buffer_idx, conv_page, ssm_page) + return + + def read_page_to_req( + self, + page_index: int, + req_idx: int, + dp_mems: List["Qwen3NextMemManager"], + ): + conv_page, ssm_page = self.view_page_to_linear_att_state(page_index) + req_buffer_idx = req_idx * (get_env_start_args().mtp_step + 1) + for tp_index, mem in enumerate(dp_mems): + self._read_one_rank(mem, tp_index, req_buffer_idx, conv_page, ssm_page) + return + + def _write_one_rank( + self, + mem: "Qwen3NextMemManager", + tp_index: int, + req_buffer_idx: int, + conv_page: torch.Tensor, + ssm_page: torch.Tensor, + ): + conv_state = mem.req_to_conv_state.buffer[:, req_buffer_idx, ...] + ssm_state = mem.req_to_ssm_state.buffer[:, req_buffer_idx, ...] 
+ self._copy_conv_state_to_page(conv_state, conv_page, mem, tp_index) + self._copy_ssm_state_to_page(ssm_state, ssm_page, mem, tp_index) + return + + def _copy_conv_state_to_page( + self, + conv_state: torch.Tensor, + conv_page: torch.Tensor, + mem: "Qwen3NextMemManager", + tp_index: int, + ): + local_q_heads = mem.linear_config.num_linear_k_heads + local_v_heads = mem.linear_config.num_linear_v_heads + head_k_dim = mem.linear_config.head_linear_k_dim + head_v_dim = mem.linear_config.head_linear_v_dim + + local_q_state = conv_state[:, 0 : local_q_heads * head_k_dim, :] + local_k_state = conv_state[:, local_q_heads * head_k_dim : 2 * local_q_heads * head_k_dim, :] + local_v_state = conv_state[:, 2 * local_q_heads * head_k_dim :, :] + global_q_page = conv_page[:, 0 : self.global_q_dim, :] + global_k_page = conv_page[:, self.global_q_dim : self.global_q_dim + self.global_k_dim, :] + global_v_page = conv_page[:, self.global_q_dim + self.global_k_dim :, :] + + qk_head_slice = self._get_head_slice( + tp_index, local_q_heads, self.global_linear_k_heads, mem.linear_config.tp_world_size, is_write=True + ) + if qk_head_slice is not None: + local_head_start, local_head_end, global_head_start, global_head_end = qk_head_slice + local_dim_start = local_head_start * head_k_dim + local_dim_end = local_head_end * head_k_dim + global_dim_start = global_head_start * head_k_dim + global_dim_end = global_head_end * head_k_dim + global_q_page[:, global_dim_start:global_dim_end, :].copy_( + local_q_state[:, local_dim_start:local_dim_end, :], non_blocking=True + ) + global_k_page[:, global_dim_start:global_dim_end, :].copy_( + local_k_state[:, local_dim_start:local_dim_end, :], non_blocking=True + ) + + v_head_slice = self._get_head_slice( + tp_index, local_v_heads, self.global_linear_v_heads, mem.linear_config.tp_world_size, is_write=True + ) + if v_head_slice is not None: + local_head_start, local_head_end, global_head_start, global_head_end = v_head_slice + local_dim_start = 
local_head_start * head_v_dim + local_dim_end = local_head_end * head_v_dim + global_dim_start = global_head_start * head_v_dim + global_dim_end = global_head_end * head_v_dim + global_v_page[:, global_dim_start:global_dim_end, :].copy_( + local_v_state[:, local_dim_start:local_dim_end, :], non_blocking=True + ) + return + + def _copy_ssm_state_to_page( + self, + ssm_state: torch.Tensor, + ssm_page: torch.Tensor, + mem: "Qwen3NextMemManager", + tp_index: int, + ): + head_slice = self._get_head_slice( + tp_index, + mem.linear_config.num_linear_v_heads, + self.global_linear_v_heads, + mem.linear_config.tp_world_size, + is_write=True, + ) + if head_slice is not None: + local_head_start, local_head_end, global_head_start, global_head_end = head_slice + ssm_page[:, global_head_start:global_head_end, :, :].copy_( + ssm_state[:, local_head_start:local_head_end, :, :], + non_blocking=True, + ) + return + + def _get_head_slice( + self, + tp_index: int, + local_heads: int, + global_heads: int, + tp_world_size: int, + is_write: bool, + ): + if local_heads == 0 or global_heads == 0: + return None + total_local_heads = local_heads * tp_world_size + repeat_count = max(1, total_local_heads // global_heads) + if is_write and repeat_count > 1 and tp_index % repeat_count != 0: + return None + unique_tp_index = tp_index // repeat_count + global_head_start = unique_tp_index * local_heads + global_head_end = min(global_head_start + local_heads, global_heads) + local_head_start = 0 + local_head_end = global_head_end - global_head_start + if local_head_end <= local_head_start: + return None + return local_head_start, local_head_end, global_head_start, global_head_end + + def _copy_page_to_conv_state( + self, + conv_page: torch.Tensor, + conv_state: torch.Tensor, + mem: "Qwen3NextMemManager", + tp_index: int, + ): + local_q_heads = mem.linear_config.num_linear_k_heads + local_v_heads = mem.linear_config.num_linear_v_heads + head_k_dim = mem.linear_config.head_linear_k_dim + head_v_dim = 
mem.linear_config.head_linear_v_dim + + local_q_state = conv_state[:, 0 : local_q_heads * head_k_dim, :] + local_k_state = conv_state[:, local_q_heads * head_k_dim : 2 * local_q_heads * head_k_dim, :] + local_v_state = conv_state[:, 2 * local_q_heads * head_k_dim :, :] + global_q_page = conv_page[:, 0 : self.global_q_dim, :] + global_k_page = conv_page[:, self.global_q_dim : self.global_q_dim + self.global_k_dim, :] + global_v_page = conv_page[:, self.global_q_dim + self.global_k_dim :, :] + + qk_head_slice = self._get_head_slice( + tp_index, local_q_heads, self.global_linear_k_heads, mem.linear_config.tp_world_size, is_write=False + ) + if qk_head_slice is not None: + local_head_start, local_head_end, global_head_start, global_head_end = qk_head_slice + local_dim_start = local_head_start * head_k_dim + local_dim_end = local_head_end * head_k_dim + global_dim_start = global_head_start * head_k_dim + global_dim_end = global_head_end * head_k_dim + local_q_state[:, local_dim_start:local_dim_end, :].copy_( + global_q_page[:, global_dim_start:global_dim_end, :], non_blocking=True + ) + local_k_state[:, local_dim_start:local_dim_end, :].copy_( + global_k_page[:, global_dim_start:global_dim_end, :], non_blocking=True + ) + + v_head_slice = self._get_head_slice( + tp_index, local_v_heads, self.global_linear_v_heads, mem.linear_config.tp_world_size, is_write=False + ) + if v_head_slice is not None: + local_head_start, local_head_end, global_head_start, global_head_end = v_head_slice + local_dim_start = local_head_start * head_v_dim + local_dim_end = local_head_end * head_v_dim + global_dim_start = global_head_start * head_v_dim + global_dim_end = global_head_end * head_v_dim + local_v_state[:, local_dim_start:local_dim_end, :].copy_( + global_v_page[:, global_dim_start:global_dim_end, :], non_blocking=True + ) + return + + def _copy_page_to_ssm_state( + self, + ssm_page: torch.Tensor, + ssm_state: torch.Tensor, + mem: "Qwen3NextMemManager", + tp_index: int, + ): + 
head_slice = self._get_head_slice( + tp_index, + mem.linear_config.num_linear_v_heads, + self.global_linear_v_heads, + mem.linear_config.tp_world_size, + is_write=False, + ) + if head_slice is not None: + local_head_start, local_head_end, global_head_start, global_head_end = head_slice + ssm_state[:, local_head_start:local_head_end, :, :].copy_( + ssm_page[:, global_head_start:global_head_end, :, :], + non_blocking=True, + ) + return + + def _read_one_rank( + self, + mem: "Qwen3NextMemManager", + tp_index: int, + req_buffer_idx: int, + conv_page: torch.Tensor, + ssm_page: torch.Tensor, + ): + conv_state = mem.req_to_conv_state.buffer[:, req_buffer_idx, ...] + ssm_state = mem.req_to_ssm_state.buffer[:, req_buffer_idx, ...] + self._copy_page_to_conv_state(conv_page, conv_state, mem, tp_index) + self._copy_page_to_ssm_state(ssm_page, ssm_state, mem, tp_index) + return diff --git a/lightllm/common/linear_att_cache_manager/config_objs.py b/lightllm/common/linear_att_cache_manager/config_objs.py index 46ab9d2107..bc39067069 100644 --- a/lightllm/common/linear_att_cache_manager/config_objs.py +++ b/lightllm/common/linear_att_cache_manager/config_objs.py @@ -18,6 +18,8 @@ class LinearAttCacheConfig: full_att_head_dim: int # linear att 的参数 + global_linear_k_heads: int + global_linear_v_heads: int num_linear_k_heads: int num_linear_v_heads: int head_linear_k_dim: int @@ -30,7 +32,14 @@ class LinearAttCacheConfig: all_layer_num: int # 包括 linear att 和 full att 的层加起来的层数 def get_conv_dim(self): - return self.head_linear_k_dim * self.num_linear_k_heads * 2 + self.head_linear_v_dim * self.num_linear_v_heads + # 第一项对应q的参数,第二项对应k的参数,第三项对应v的参数 + # 由于 k_dim = q_dim, k_heads = q_heads, 所以第一项和第二项的计算 + # 形式相同,但是实际内在含义是不同的。 + return ( + self.head_linear_k_dim * self.num_linear_k_heads + + self.head_linear_k_dim * self.num_linear_k_heads + + self.head_linear_v_dim * self.num_linear_v_heads + ) def get_conv_state_shape(self): return (self.get_conv_dim(), self.conv_kernel_size - 1) @@ -92,8 
+101,10 @@ def load_from_args() -> "LinearAttCacheConfig": full_att_dtype=get_torch_dtype(args.data_type), full_att_num_kv_heads=max(1, llm_config["num_key_value_heads"] // tp_world_size), full_att_head_dim=llm_config["head_dim"], - num_linear_k_heads=llm_config["linear_num_key_heads"] // tp_world_size, - num_linear_v_heads=llm_config["linear_num_value_heads"] // tp_world_size, + global_linear_k_heads=llm_config["linear_num_key_heads"], + global_linear_v_heads=llm_config["linear_num_value_heads"], + num_linear_k_heads=max(1, llm_config["linear_num_key_heads"] // tp_world_size), + num_linear_v_heads=max(1, llm_config["linear_num_value_heads"] // tp_world_size), head_linear_k_dim=llm_config["linear_key_head_dim"], head_linear_v_dim=llm_config["linear_value_head_dim"], conv_kernel_size=llm_config["linear_conv_kernel_dim"], diff --git a/lightllm/models/qwen3next/model.py b/lightllm/models/qwen3next/model.py index e3c51f3617..9b5e9b7a50 100644 --- a/lightllm/models/qwen3next/model.py +++ b/lightllm/models/qwen3next/model.py @@ -65,8 +65,10 @@ def _init_mem_manager(self): full_att_dtype=self.data_type, full_att_num_kv_heads=self.num_kv_heads, full_att_head_dim=self.config["head_dim"], - num_linear_k_heads=self.config["linear_num_key_heads"] // self.tp_world_size_, - num_linear_v_heads=self.config["linear_num_value_heads"] // self.tp_world_size_, + global_linear_k_heads=self.config["linear_num_key_heads"], + global_linear_v_heads=self.config["linear_num_value_heads"], + num_linear_k_heads=max(1, self.config["linear_num_key_heads"] // self.tp_world_size_), + num_linear_v_heads=max(1, self.config["linear_num_value_heads"] // self.tp_world_size_), head_linear_k_dim=self.config["linear_key_head_dim"], head_linear_v_dim=self.config["linear_value_head_dim"], conv_kernel_size=self.config["linear_conv_kernel_dim"], diff --git a/lightllm/server/pd_io_struct.py b/lightllm/server/pd_io_struct.py index 6dfa8bcbeb..35e73e3007 100644 --- a/lightllm/server/pd_io_struct.py +++ 
b/lightllm/server/pd_io_struct.py @@ -2,7 +2,7 @@ import time import copy from dataclasses import dataclass, field -from typing import Dict, List, Optional, Tuple, Union, Set +from typing import Dict, List, Optional from lightllm.server.req_id_generator import convert_sub_id_to_group_id from fastapi import WebSocket @@ -287,13 +287,22 @@ class NIXLChunckedTransTask: error_info: Optional[str] = None transfer_time_out_secs: int = 66 + page_kind: str = "kv" + # Only valid for the local task owner; remote notify copies may carry the sender-local req_idx. + req_idx: Optional[int] = None def __post_init__(self): if self.start_kv_index < 0 or self.end_kv_index < self.start_kv_index: error_info = "start_kv_index must >=0 and end_kv_index > start_kv_index" logger.error(error_info) raise ValueError(error_info) - assert len(self.mem_indexes) == (self.end_kv_index - self.start_kv_index) + if self.page_kind == "kv": + assert len(self.mem_indexes) == (self.end_kv_index - self.start_kv_index) + elif self.page_kind == "linear_att_state": + assert self.start_kv_index == self.end_kv_index + assert len(self.mem_indexes) == 0 + else: + raise ValueError(f"unknown NIXL trans page kind {self.page_kind}") self.create_time = time.time() return @@ -315,7 +324,7 @@ def transfer_time(self): return time.time() - self.start_trans_time def get_key(self) -> str: - return f"{self.request_id}_{self.start_kv_index}_{self.end_kv_index}" + return f"{self.request_id}_{self.req_idx}_{self.page_kind}_{self.start_kv_index}_{self.end_kv_index}" def to_str(self): obj: NIXLChunckedTransTask = copy.copy(self) @@ -331,8 +340,13 @@ def to_str(self): return obj.__str__() def transfer_kv_num(self): + if self.page_kind != "kv": + return 0 return self.end_kv_index - self.start_kv_index + def need_transfer_page(self): + return self.page_kind != "kv" or self.transfer_kv_num() != 0 + def createRetObj(self) -> "NIXLChunckedTransTaskRet": ret = NIXLChunckedTransTaskRet( request_id=self.request_id, diff --git 
a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index f0ec69b2c1..6c4b19e65c 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -372,20 +372,23 @@ def copy_linear_att_state_to_cache_buffer(self, b_req_idx: torch.Tensor, reqs: L big_page_buffer_ids.append(-1) assert len(b_req_idx) == len(big_page_buffer_ids) - big_page_buffer_ids = torch.tensor(big_page_buffer_ids, dtype=torch.int32, requires_grad=False, device="cpu") - big_page_buffer_ids = big_page_buffer_ids.cuda(non_blocking=True) - - from lightllm.common.basemodel.triton_kernel.linear_att_copy import copy_linear_att_state_to_kv_buffer - - copy_linear_att_state_to_kv_buffer( - b_req_idx=b_req_idx, - big_page_buffer_ids=big_page_buffer_ids, - gpu_conv_state=self.req_manager.req_to_conv_state.buffer, - gpu_ssm_state=self.req_manager.req_to_ssm_state.buffer, - cpu_kv_conv_state=self.radix_cache.linear_att_big_page_buffers.conv_state_cache.buffer, - cpu_kv_ssm_state=self.radix_cache.linear_att_big_page_buffers.ssm_state_cache.buffer, - mtp_step=self.args.mtp_step, - ) + if any(buffer_id != -1 for buffer_id in big_page_buffer_ids): + big_page_buffer_ids = torch.tensor( + big_page_buffer_ids, dtype=torch.int32, requires_grad=False, device="cpu" + ) + big_page_buffer_ids = big_page_buffer_ids.cuda(non_blocking=True) + + from lightllm.common.basemodel.triton_kernel.linear_att_copy import copy_linear_att_state_to_kv_buffer + + copy_linear_att_state_to_kv_buffer( + b_req_idx=b_req_idx, + big_page_buffer_ids=big_page_buffer_ids, + gpu_conv_state=self.req_manager.req_to_conv_state.buffer, + gpu_ssm_state=self.req_manager.req_to_ssm_state.buffer, + cpu_kv_conv_state=self.radix_cache.linear_att_big_page_buffers.conv_state_cache.buffer, + cpu_kv_ssm_state=self.radix_cache.linear_att_big_page_buffers.ssm_state_cache.buffer, + mtp_step=self.args.mtp_step, + ) assert not self.args.disable_chunked_prefill, 
"chunked prefill mode must be enabled for linear att mixed model" diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_impl.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_impl.py index 3ffa15b154..bba3d11965 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_impl.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_impl.py @@ -150,6 +150,17 @@ def _decode_node_gen_trans_tasks(self, req_obj: InferReq): req_obj.nixl_trans_kv_start_index += cur_page_size req_obj.cur_kv_len += len(mem_indexes) + + # 如果当前是linear att 混合模型,则需要创建一个linear att 状态的传输任务 + if g_infer_context.is_linear_att_mixed_model: + self._create_nixl_trans_task( + req_obj=req_obj, + mem_indexes=[], + kv_start_index=input_len, + kv_end_index=input_len, + group=group, + page_kind="linear_att_state", + ) else: assert req_obj.cur_kv_len == input_len - 1 @@ -175,6 +186,7 @@ def _create_nixl_trans_task( kv_start_index: int, kv_end_index: int, group: NIXLChunckedTransTaskGroup, + page_kind: str = "kv", ): # 确定传输设备 if req_obj.nixl_trans_device_id == -1: @@ -184,6 +196,13 @@ def _create_nixl_trans_task( # only self.is_master_in_dp will be used. 
self.nixl_iter_device_id = (self.nixl_iter_device_id + 1) % self.node_world_size + if page_kind == "kv": + req_idx = None + elif page_kind == "linear_att_state": + req_idx = req_obj.req_idx + else: + raise ValueError(f"unknown NIXL trans page kind {page_kind}") + trans_task = NIXLChunckedTransTask( request_id=req_obj.req_id, start_kv_index=kv_start_index, @@ -205,6 +224,8 @@ def _create_nixl_trans_task( decode_page_reg_desc=None, first_gen_token_id=None, first_gen_token_logprob=None, + page_kind=page_kind, + req_idx=req_idx, ) group.task_list.append(trans_task) req_obj.nixl_pd_task_num += 1 diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_impl_for_dp.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_impl_for_dp.py index 8bf0dd7c51..dc46d795a9 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_impl_for_dp.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_impl_for_dp.py @@ -37,6 +37,7 @@ def _create_nixl_trans_task( kv_start_index: int, kv_end_index: int, group: NIXLChunckedTransTaskGroup, + page_kind: str = "kv", ): return NIXLDecodeNode._create_nixl_trans_task( self, @@ -45,4 +46,5 @@ def _create_nixl_trans_task( kv_start_index=kv_start_index, kv_end_index=kv_end_index, group=group, + page_kind=page_kind, ) diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py index aeac6b97d8..776f41f24a 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py @@ -203,7 +203,7 @@ def dispatch_task_loop(self): with self.waiting_dict_lock: for task in trans_task_group.task_list: - if task.transfer_kv_num() != 0: + if 
task.need_transfer_page(): self.waiting_dict[task.get_key()] = task else: task.start_trans_time = time.time() @@ -385,6 +385,8 @@ def read_page_to_mems_loop(self): dp_index=trans_task.decode_dp_index, mem_managers=self.mem_managers, dp_world_size=self.dp_world_size, + page_kind=trans_task.page_kind, + req_idx=trans_task.req_idx, ) copy_end_event.record(self.copy_cuda_stream) self.success_queue.put((copy_end_event, copy_start_event, trans_task)) diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl.py index b75c60b8ca..87d72df75c 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl.py @@ -1,6 +1,6 @@ import torch.multiprocessing as mp import random -from typing import List, Tuple, Optional +from typing import List, Tuple from lightllm.server.router.model_infer.infer_batch import InferReq from lightllm.server.pd_io_struct import NIXLChunckedTransTask from lightllm.utils.log_utils import init_logger @@ -76,6 +76,15 @@ def _prefill_chuncked_handle_func( break if prefill_finished and len(trans_task_list) != 0 and output_len == 1: + if g_infer_context.is_linear_att_mixed_model: + trans_task_list.append( + self._create_nixl_trans_task( + req_obj=req_obj, + kv_start_index=input_len, + kv_end_index=input_len, + page_kind="linear_att_state", + ) + ) trans_task_list[-1].first_gen_token_id = next_token_id trans_task_list[-1].first_gen_token_logprob = next_token_prob @@ -85,7 +94,11 @@ def _prefill_chuncked_handle_func( return def _create_nixl_trans_task( - self, req_obj: InferReq, kv_start_index: int, kv_end_index: int + self, + req_obj: InferReq, + kv_start_index: int, + kv_end_index: int, + page_kind: str = "kv", ) -> NIXLChunckedTransTask: # 确定传输设备 if req_obj.nixl_trans_device_id == -1: @@ -95,12 
+108,19 @@ def _create_nixl_trans_task( self.nixl_iter_device_id = (self.nixl_iter_device_id + 1) % self.node_world_size nixl_decode_node_info = req_obj.sampling_param.nixl_decode_node - mem_indexes = ( - self.model.req_manager.req_to_token_indexs[req_obj.req_idx, kv_start_index:kv_end_index] - .detach() - .cpu() - .tolist() - ) + if page_kind == "kv": + mem_indexes = ( + self.model.req_manager.req_to_token_indexs[req_obj.req_idx, kv_start_index:kv_end_index] + .detach() + .cpu() + .tolist() + ) + req_idx = None + elif page_kind == "linear_att_state": + mem_indexes = [] + req_idx = req_obj.req_idx + else: + raise ValueError(f"unknown NIXL trans page kind {page_kind}") trans_task = NIXLChunckedTransTask( request_id=req_obj.req_id, start_kv_index=kv_start_index, @@ -122,6 +142,8 @@ def _create_nixl_trans_task( decode_page_reg_desc=nixl_decode_node_info.page_reg_desc, first_gen_token_id=None, first_gen_token_logprob=None, + page_kind=page_kind, + req_idx=req_idx, ) req_obj.nixl_pd_task_num += 1 return trans_task diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl_for_dp.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl_for_dp.py index eed98399e7..daa041afea 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl_for_dp.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl_for_dp.py @@ -1,5 +1,5 @@ import torch.multiprocessing as mp -from typing import List, Tuple, Optional +from typing import List, Tuple from lightllm.server.router.model_infer.infer_batch import InferReq from lightllm.utils.log_utils import init_logger from .prefill_impl import NIXLChunckedPrefillForPrefillNode, NIXLChunckedTransTask @@ -31,8 +31,12 @@ def _prefill_chuncked_handle_func( ) def _create_nixl_trans_task( - self, req_obj: InferReq, kv_start_index: int, kv_end_index: int + self, req_obj: InferReq, kv_start_index: int, 
kv_end_index: int, page_kind: str = "kv" ) -> NIXLChunckedTransTask: return NIXLChunckedPrefillForPrefillNode._create_nixl_trans_task( - self, req_obj=req_obj, kv_start_index=kv_start_index, kv_end_index=kv_end_index + self, + req_obj=req_obj, + kv_start_index=kv_start_index, + kv_end_index=kv_end_index, + page_kind=page_kind, ) diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_trans_process.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_trans_process.py index f9dcd4d3eb..cd124445ea 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_trans_process.py @@ -190,6 +190,8 @@ def local_copy_kv_loop(self): dp_index=trans_task.prefill_dp_index, mem_managers=self.mem_managers, dp_world_size=self.dp_world_size, + page_kind=trans_task.page_kind, + req_idx=trans_task.req_idx, ) sync_event = torch.cuda.Event() sync_event.record() From d6bf28b4f8daac8168d5b9571f50ea7acdb2273f Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 9 Jun 2026 06:17:23 +0000 Subject: [PATCH 2/3] fix --- lightllm/server/pd_io_struct.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightllm/server/pd_io_struct.py b/lightllm/server/pd_io_struct.py index 35e73e3007..fe0259855d 100644 --- a/lightllm/server/pd_io_struct.py +++ b/lightllm/server/pd_io_struct.py @@ -324,7 +324,7 @@ def transfer_time(self): return time.time() - self.start_trans_time def get_key(self) -> str: - return f"{self.request_id}_{self.req_idx}_{self.page_kind}_{self.start_kv_index}_{self.end_kv_index}" + return f"{self.request_id}_{self.page_kind}_{self.start_kv_index}_{self.end_kv_index}" def to_str(self): obj: NIXLChunckedTransTask = copy.copy(self) From 79995138798eb3a53353e19dbba08f310d606128 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Tue, 9 Jun 2026 06:21:49 +0000 Subject: 
[PATCH 3/3] fix --- .../test_model/qwen3.5-0.8b-pd-nixl/SKILL.md | 313 ++++++++++++++++++ .../check_nvidia_peermem.sh | 43 +++ 2 files changed, 356 insertions(+) create mode 100644 skills/test_model/qwen3.5-0.8b-pd-nixl/SKILL.md create mode 100755 skills/test_model/qwen3.5-0.8b-pd-nixl/check_nvidia_peermem.sh diff --git a/skills/test_model/qwen3.5-0.8b-pd-nixl/SKILL.md b/skills/test_model/qwen3.5-0.8b-pd-nixl/SKILL.md new file mode 100644 index 0000000000..983f76e551 --- /dev/null +++ b/skills/test_model/qwen3.5-0.8b-pd-nixl/SKILL.md @@ -0,0 +1,313 @@ +--- +name: test-model-qwen3.5-0.8b-pd-nixl +description: >- + LightLLM Qwen3.5-0.8B PD disaggregation over NIXL gsm8k: pd_master on 8089, + nixl_prefill on 8001, nixl_decode on 8002. Supports TP1 and TP2 runs by setting + TP / PREFILL_CUDA_DEVICES / DECODE_CUDA_DEVICES. Qwen3.5 has linear-attention + state transfer; use --nixl_pd_kv_page_size 2048 and a large enough page_num + such as 256. lm_eval hits pd_master URL. Requires UCX/RDMA env, nvidia_peermem + check, curl warmup before lm_eval, registration wait in pd_master.log, and + summary.txt. Includes optional repeated-prompt decode cache probe for linear-att + page-boundary behavior. 
+--- + +# Qwen3.5-0.8B **PD 分离(NIXL)** 本地 GSM8K 评测 + +**测试标识**:同一 **`MODEL_DIR`(Qwen3.5-0.8B)** 下拆三条 `api_server` 进程: +**`pd_master`**、**`nixl_prefill`**、**`nixl_decode`**。评测和 warmup 只访问 +**`pd_master` 的 HTTP 端口 `8089`**。 + +Qwen3.5 与 Qwen3-8B 的关键差异: + +| 项 | Qwen3.5-0.8B NIXL PD 要点 | +|---|---| +| linear-att 状态 | PD 传输除了 KV page,还会传 `linear_att_state` 特殊页 | +| NIXL page size | 建议固定 **`--nixl_pd_kv_page_size 2048`**;`1024` 可能不足以容纳 linear-att 状态 | +| page num | 建议 **`--nixl_pd_kv_page_num 256`** 起步,避免 page 池过小干扰评测 | +| cache 判断 | repeated prompt 可能只在 prefill 侧命中,decode 侧不一定 decode-only 命中 | + +## 日志目录 + +每轮使用独立 `LOG_DIR`,至少保留: + +- `summary.txt` +- `pd_master.log` +- `prefill.log` +- `decode.log` +- `curl_warmup.log` +- `eval_gsm8k.log` + +建议命名: + +```bash +export LOG_DIR="/mtc/wzj/lightllm_dev2/LightLLM/test/benchmark/static_inference/log/qwen35_pd_nixl_$(date +%Y%m%d_%H%M%S)" +mkdir -p "${LOG_DIR}" +``` + +## 启动前检查 + +1. **模型目录**:优先使用 `MODEL_DIR=/mtc/models/Qwen3.5-0.8B`;不存在时再改成本机实际路径。 +2. **端口**:确认 `8089`、`8001`、`8002` 空闲。 +3. **显卡**:TP1 需要 prefill/decode 各 1 张卡;TP2 需要 prefill/decode 各 2 张卡,互不重叠。 +4. **代理**:启动服务和评测前清空 `http_proxy` / `https_proxy`;评测设置 `no_proxy`。 +5. **UCX/RDMA**:prefill/decode 启动前设置 `UCX_NET_DEVICES`、`UCX_TLS`。本机若默认 UCX 打到 `mlx5_8` 报 `Address not valid`,可显式使用 `mlx5_0:1` 到 `mlx5_7:1`。 +6. **nvidia_peermem**:运行本目录的 `check_nvidia_peermem.sh`,结果写入 `summary.txt`。 +7. 
**MPS**:如需更稳定的高并发/传输性能,可在启动服务前开启 NVIDIA MPS,并把开启状态写入 `summary.txt`。 + +## 变量配置 + +### TP2 推荐配置 + +```bash +export MODEL_DIR=/mtc/models/Qwen3.5-0.8B +export MODEL_NAME='qwen/Qwen3.5-0.8B' +export TP=2 +export PREFILL_CUDA_DEVICES='0,1' +export DECODE_CUDA_DEVICES='2,3' +export NIXL_PD_KV_PAGE_SIZE=2048 +export NIXL_PD_KV_PAGE_NUM=256 +export PD_MASTER_IP="$(hostname -I | awk '{print $1}')" +export HOST="${PD_MASTER_IP}" +``` + +### TP1 快速验证配置 + +```bash +export MODEL_DIR=/mtc/models/Qwen3.5-0.8B +export MODEL_NAME='qwen/Qwen3.5-0.8B' +export TP=1 +export PREFILL_CUDA_DEVICES='4' +export DECODE_CUDA_DEVICES='5' +export NIXL_PD_KV_PAGE_SIZE=2048 +export NIXL_PD_KV_PAGE_NUM=256 +export PD_MASTER_IP="$(hostname -I | awk '{print $1}')" +export HOST="${PD_MASTER_IP}" +``` + +### UCX 示例 + +按本机拓扑调整,不要盲目照抄其它机器: + +```bash +export UCX_NET_DEVICES='mlx5_0:1,mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1' +export UCX_TLS=rc,cuda,gdr_copy +``` + +## 启动命令 + +先写入基础信息: + +```bash +export http_proxy= +export https_proxy= +export no_proxy=localhost,127.0.0.1,0.0.0.0,::1,${PD_MASTER_IP} + +{ + echo "MODEL_DIR=${MODEL_DIR}" + echo "MODEL_NAME=${MODEL_NAME}" + echo "TP=${TP}" + echo "PREFILL_CUDA_DEVICES=${PREFILL_CUDA_DEVICES}" + echo "DECODE_CUDA_DEVICES=${DECODE_CUDA_DEVICES}" + echo "NIXL_PD_KV_PAGE_SIZE=${NIXL_PD_KV_PAGE_SIZE}" + echo "NIXL_PD_KV_PAGE_NUM=${NIXL_PD_KV_PAGE_NUM}" + echo "PD_MASTER_IP=${PD_MASTER_IP}" + echo "HOST=${HOST}" + echo "UCX_NET_DEVICES=${UCX_NET_DEVICES-}" + echo "UCX_TLS=${UCX_TLS-}" +} | tee "${LOG_DIR}/summary.txt" + +bash skills/test_model/qwen3.5-0.8b-pd-nixl/check_nvidia_peermem.sh >> "${LOG_DIR}/summary.txt" 2>&1 +``` + +### 1. 启动 `pd_master` + +```bash +nohup python -m lightllm.server.api_server \ + --model_dir "${MODEL_DIR}" \ + --run_mode pd_master \ + --host "${PD_MASTER_IP}" \ + --port 8089 \ + >> "${LOG_DIR}/pd_master.log" 2>&1 & +``` + +等待 `8089` listen 后再启动节点。 + +### 2. 
启动 `nixl_prefill` + +```bash +LOADWORKER=18 CUDA_VISIBLE_DEVICES="${PREFILL_CUDA_DEVICES}" \ +nohup python -m lightllm.server.api_server \ + --model_dir "${MODEL_DIR}" \ + --run_mode nixl_prefill \ + --tp "${TP}" \ + --dp 1 \ + --host "${HOST}" \ + --port 8001 \ + --disable_cudagraph \ + --pd_master_ip "${PD_MASTER_IP}" \ + --pd_master_port 8089 \ + --nixl_pd_kv_page_size "${NIXL_PD_KV_PAGE_SIZE}" \ + --nixl_pd_kv_page_num "${NIXL_PD_KV_PAGE_NUM}" \ + >> "${LOG_DIR}/prefill.log" 2>&1 & +``` + +### 3. 启动 `nixl_decode` + +```bash +LOADWORKER=18 CUDA_VISIBLE_DEVICES="${DECODE_CUDA_DEVICES}" \ +nohup python -m lightllm.server.api_server \ + --model_dir "${MODEL_DIR}" \ + --run_mode nixl_decode \ + --tp "${TP}" \ + --dp 1 \ + --host "${HOST}" \ + --port 8002 \ + --pd_master_ip "${PD_MASTER_IP}" \ + --pd_master_port 8089 \ + --nixl_pd_kv_page_size "${NIXL_PD_KV_PAGE_SIZE}" \ + --nixl_pd_kv_page_num "${NIXL_PD_KV_PAGE_NUM}" \ + >> "${LOG_DIR}/decode.log" 2>&1 & +``` + +## 就绪判定 + +不要只看端口。必须等待 `pd_master.log` 同时出现: + +```text +mode: nixl_prefill ... registed +mode: nixl_decode ... 
registed +``` + +可用命令: + +```bash +rg 'mode: nixl_prefill .* registed|mode: nixl_decode .* registed|ERROR|Traceback|Exception' "${LOG_DIR}/pd_master.log" "${LOG_DIR}/prefill.log" "${LOG_DIR}/decode.log" +``` + +## Warmup + +`lm_eval` 前必须先打一次 `pd_master`: + +```bash +curl -sS -w "\nhttp_code:%{http_code}\n" -X POST "http://${PD_MASTER_IP}:8089/v1/completions" \ + -H "Content-Type: application/json" \ + -d "{\"model\":\"${MODEL_NAME}\",\"prompt\":\"warmup\",\"max_tokens\":16,\"temperature\":0}" \ + | tee "${LOG_DIR}/curl_warmup.log" +``` + +期望 `http_code:200`。失败时先查 `pd_master.log` / `prefill.log` / `decode.log`,不要直接跑全量评测。 + +## GSM8K 评测 + +默认并发和 batch 使用 64,避免高并发掩盖关键问题;压测时再提高。 + +```bash +export http_proxy= +export https_proxy= +export no_proxy=localhost,127.0.0.1,0.0.0.0,::1,${PD_MASTER_IP} + +HF_ALLOW_CODE_EVAL=1 HF_DATASETS_OFFLINE=0 \ +lm_eval --model local-completions \ + --model_args "model=${MODEL_NAME},base_url=http://${PD_MASTER_IP}:8089/v1/completions,num_concurrent=64,max_retries=3,tokenized_requests=False,tokenizer=${MODEL_DIR}" \ + --tasks gsm8k \ + --batch_size 64 \ + --confirm_run_unsafe_code \ + >> "${LOG_DIR}/eval_gsm8k.log" 2>&1 +``` + +提取结果: + +```bash +rg -n 'flexible-extract|strict-match|exact_match|Traceback|ERROR|can not find waiting WRITE task|has_error=True' \ + "${LOG_DIR}/eval_gsm8k.log" "${LOG_DIR}/pd_master.log" "${LOG_DIR}/prefill.log" "${LOG_DIR}/decode.log" \ + | tee -a "${LOG_DIR}/summary.txt" +``` + +参考正常结果: + +| 场景 | 参考精度 | +|---|---| +| TP1 NIXL PD | `flexible-extract exact_match ~= 0.332`,`strict-match exact_match ~= 0.327` | +| TP2 NIXL PD | `flexible-extract exact_match ~= 0.331`,`strict-match exact_match ~= 0.328` | + +## 可选:decode-only cache 命中探针 + +这个探针用于确认重复 prompt 是否在 decode 节点全命中。Qwen3.5 的 linear-att cache 以 +`linear_att_hash_page_size` 为粒度,默认 `512`。历史观察显示: + +- prefill 侧会按 512 token 粒度逐步命中,例如 513 的第二次可命中 512。 +- decode 侧可能仍为 `gpu cache hit: False`、`gpu_prompt_cache_len:0`。 +- 只要 decode 未全命中,仍会出现 `recv WRITE request from 
prefill` 和 `linear_att_state` 传输。 + +### 简单重复 prompt + +在同一套服务生命周期内连续请求两次相同 prompt: + +```bash +PROMPT_FILE="${LOG_DIR}/repeat_prompt.txt" +python3 - <<'PY' "${MODEL_DIR}" "${PROMPT_FILE}" +from transformers import AutoTokenizer +import sys +tok = AutoTokenizer.from_pretrained(sys.argv[1], trust_remote_code=True) +target = 2049 +s = "Qwen3.5 linear attention cache boundary probe. " +unit = " Repeatable cache probe sentence." +while len(tok.encode(s, add_special_tokens=False)) < target: + s += unit +open(sys.argv[2], "w").write(s) +print(len(tok.encode(s, add_special_tokens=False))) +PY + +for i in 1 2; do + curl -sS -X POST "http://${PD_MASTER_IP}:8089/v1/completions" \ + -H "Content-Type: application/json" \ + -d "{\"model\":\"${MODEL_NAME}\",\"prompt\":$(python3 -c 'import json,sys; print(json.dumps(open(sys.argv[1]).read()))' "${PROMPT_FILE}"),\"max_tokens\":4,\"temperature\":0}" \ + > "${LOG_DIR}/repeat_${i}.json" + sleep 2 +done +``` + +### 判定信号 + +```bash +rg -n 'gpu cache hit:|recv WRITE request from prefill|start WRITE to decode node|linear_att_state|trans task ret success' \ + "${LOG_DIR}/prefill.log" "${LOG_DIR}/decode.log" \ + | tee -a "${LOG_DIR}/summary.txt" +``` + +decode-only 全命中的期望信号: + +| 日志 | 期望 | +|---|---| +| `decode.log` | `gpu cache hit: True` | +| `decode.log` | `gpu_prompt_cache_len` 接近 `prompt_tokens` 或至少 `input_len - cur_kv_len <= 1` | +| `decode.log` | 不再出现真实 `recv WRITE request from prefill` | +| `prefill.log` | 不再出现对应请求的 `start WRITE to decode node` | + +如果 decode 仍是 `gpu cache hit: False gpu_prompt_cache_len:0`,则说明没有进入 decode-only 命中路径。 + +## 常见问题 + +| 现象 | 处理 | +|---|---| +| `NIXL_ERR_BACKEND` / `uct_iface_open(rc_verbs/mlx5_8:1) failed: Address not valid` | 显式设置可用 `UCX_NET_DEVICES`,例如避开 `mlx5_8/9` | +| `digest sent was rejected` | 多为快速重启后的共享内存 / multiprocessing authkey 残留;清理端口和残留 `lightllm::...` worker 后重启 | +| `can not find waiting WRITE task` | 检查 NIXL notify key、abort 日志、以及 `pd_io_struct.py` 中 key 是否包含进程本地 `req_idx` | +| 1024 page 
size 失败 | Qwen3.5 linear-att state 页可能放不下;使用 `--nixl_pd_kv_page_size 2048` | +| 第二次同 prompt 仍走 WRITE | 可能是 decode 侧没有建立可复用 cache,或 linear-att 尾块状态无法全命中 | + +## 收尾 + +结束后释放本轮服务: + +```bash +fuser -k 8089/tcp 8001/tcp 8002/tcp || true +``` + +如仍有显存占用,检查残留 worker: + +```bash +ps -eo pid,ppid,stat,cmd | rg 'lightllm::|api_server|hypercorn' +nvidia-smi --query-compute-apps=pid,process_name,used_memory --format=csv +``` + diff --git a/skills/test_model/qwen3.5-0.8b-pd-nixl/check_nvidia_peermem.sh b/skills/test_model/qwen3.5-0.8b-pd-nixl/check_nvidia_peermem.sh new file mode 100755 index 0000000000..6c0fbbc118 --- /dev/null +++ b/skills/test_model/qwen3.5-0.8b-pd-nixl/check_nvidia_peermem.sh @@ -0,0 +1,43 @@ +#!/usr/bin/env bash +# Check nvidia_peermem (GPUDirect RDMA) for NIXL PD / UCX over IB. +# Usage: bash skills/test_model/qwen3.5-0.8b-pd-nixl/check_nvidia_peermem.sh [LOG_DIR] +# LOG_DIR optional: scan prefill.log / decode.log for UCX GPUDirect lines. +set -euo pipefail + +LOG_DIR="${1:-}" +FAIL=0 + +echo "=== nvidia_peermem ===" + +if lsmod 2>/dev/null | awk '$1 == "nvidia_peermem" { found = 1 } END { exit !found }'; then # awk drains all of lsmod's output; an early-exiting grep -q could SIGPIPE a stage and pipefail would misreport "not loaded" + ver="$(cat /sys/module/nvidia_peermem/version 2>/dev/null || echo '?')" + echo "OK: module loaded (version ${ver})" +else + echo "FAIL: nvidia_peermem not loaded" + FAIL=1 +fi + +if [[ -n "$LOG_DIR" ]]; then + for f in prefill.log decode.log; do + [[ -f "${LOG_DIR}/${f}" ]] || continue + if grep -q 'GPUDirect RDMA is not detected' "${LOG_DIR}/${f}" 2>/dev/null; then + echo "FAIL: ${f} -> GPUDirect RDMA is not detected (restart services after modprobe)" + FAIL=1 + elif grep -q 'GPUDirect RDMA is detected' "${LOG_DIR}/${f}" 2>/dev/null; then + echo "OK: ${f} -> GPUDirect RDMA is detected" + fi + done +fi + +if [[ "$FAIL" -ne 0 ]]; then + cat <<'EOF' + +Enable GPUDirect RDMA: + sudo modprobe nvidia_peermem + lsmod | grep nvidia_peermem + # cross-node: run on every host; then restart nixl_prefill / nixl_decode +EOF + exit 1 +fi + +exit 0