Merged
1 change: 1 addition & 0 deletions .gitignore
@@ -7,3 +7,4 @@ dist
.vscode
tmp/
requirements-musa.txt
logs/
1 change: 0 additions & 1 deletion docs/CN/source/tutorial/api_server_args.rst
@@ -27,7 +27,6 @@ APIServer 参数详解
- ``running_max_req_size`` 为 3
- ``batch_max_tokens`` 为 2048 (2k)
- ``chunked_prefill_size`` 为 1024 (1k)
- ``mem_fraction`` 为 0.85

.. option:: --host

1 change: 0 additions & 1 deletion docs/EN/source/tutorial/api_server_args.rst
@@ -27,7 +27,6 @@ Basic Configuration Parameters
- ``running_max_req_size`` to 3
- ``batch_max_tokens`` to 2048 (2k)
- ``chunked_prefill_size`` to 1024 (1k)
- ``mem_fraction`` to 0.85

.. option:: --host

8 changes: 4 additions & 4 deletions lightllm/common/basemodel/attention/nsa/flashmla_sparse.py
@@ -79,17 +79,17 @@ def _nsa_prefill_att(
from sgl_kernel.flash_mla import flash_mla_sparse_fwd

nsa_dict = att_control.nsa_prefill_dict
topk_indices = nsa_dict["topk_indices"]
topk_mem_indices = nsa_dict["topk_mem_indices"]
softmax_scale = nsa_dict["softmax_scale"]
kv_lora_rank = nsa_dict["kv_lora_rank"]

if topk_indices.ndim == 2:
topk_indices = topk_indices.unsqueeze(1)
if topk_mem_indices.ndim == 2:
topk_mem_indices = topk_mem_indices.unsqueeze(1)

mla_out, _, _ = flash_mla_sparse_fwd(
q=q,
kv=kv,
indices=topk_indices,
indices=topk_mem_indices,
sm_scale=softmax_scale,
d_v=kv_lora_rank,
)
2 changes: 2 additions & 0 deletions lightllm/common/kv_cache_mem_manager/__init__.py
@@ -1,3 +1,4 @@
from .allocator import KvCacheAllocator
from .mem_manager import MemoryManager, ReadOnlyStaticsMemoryManager
from .ppl_int8kv_mem_manager import PPLINT8KVMemoryManager
from .ppl_int4kv_mem_manager import PPLINT4KVMemoryManager
@@ -9,6 +10,7 @@
from .qwen3next_mem_manager import Qwen3NextMemManager

__all__ = [
"KvCacheAllocator",
"MemoryManager",
"ReadOnlyStaticsMemoryManager",
"PPLINT4KVMemoryManager",
90 changes: 12 additions & 78 deletions lightllm/common/kv_cache_mem_manager/mem_manager.py
@@ -3,10 +3,11 @@
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from typing import List, Union, Tuple, Any
from typing import List, Tuple, Any, Union
from lightllm.server.pd_io_struct import KVMoveTask
from lightllm.utils.log_utils import init_logger
from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt
from .allocator import KvCacheAllocator
from lightllm.utils.profile_max_tokens import get_available_gpu_memory, get_total_gpu_memory
from lightllm.common.kv_trans_kernel.kv_trans import kv_trans
from lightllm.utils.dist_utils import get_current_rank_in_node, get_node_world_size
@@ -38,27 +39,8 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False
# profile the max total token num if the size is None
self.profile_size(mem_fraction)

self.mem_state = torch.arange(
0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
)
self._mem_state_return = torch.arange(
0, self.size * 3, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
)
self._return_start = 0
self.mark_start = 0
self.mark_end = self.size

self.can_use_mem_size = self.size

# Shared via shared memory so the router module can read it for precise scheduling estimates; the nccl port acts as a per-instance marker on a single machine to avoid conflicts.
from lightllm.utils.envs_utils import get_unique_server_name
self.allocator = KvCacheAllocator(self.size)

rank_in_node = get_current_rank_in_node()
self.shared_can_use_token_num = SharedInt(
f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}"
)

self.shared_can_use_token_num.set_value(self.can_use_mem_size)
self._init_buffers(
self.size,
dtype,
@@ -83,9 +65,10 @@ def profile_size(self, mem_fraction):
if self.size is not None:
return

torch.cuda.empty_cache()
world_size = dist.get_world_size()
total_memory = get_total_gpu_memory()
available_memory = get_available_gpu_memory(world_size) - total_memory * (1 - mem_fraction)

available_memory = get_available_gpu_memory(world_size) * mem_fraction
cell_size = self.get_cell_size()
self.size = int(available_memory * 1024 ** 3 / cell_size)
if world_size > 1:
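
For readers comparing the two budget formulas above (the removed and the added line), a back-of-the-envelope check; the GB figures below are invented purely for illustration and do not come from this PR:

```python
# Hypothetical numbers for illustration only -- not measured values from lightllm.
total_gb = 80.0        # what get_total_gpu_memory() might report
available_gb = 70.0    # what get_available_gpu_memory(world_size) might report
mem_fraction = 0.8

old_budget = available_gb - total_gb * (1 - mem_fraction)  # 70 - 80 * 0.2 = 54.0 GB
new_budget = available_gb * mem_fraction                   # 70 * 0.8       = 56.0 GB
print(old_budget, new_budget)
```

Which formula yields the larger budget depends on how much memory is already in use; the new form never goes negative, while the old one could when available memory is much smaller than total memory.
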
@@ -338,57 +321,13 @@ def _free_buffers(self):
self.kv_buffer = None

def alloc(self, need_size) -> torch.Tensor:
if need_size > self.mark_end - self.mark_start:
logger.error(f"warn no enough cache need_size {need_size} left_size {self.can_use_mem_size}")
assert False, "error alloc state"

start = self.mark_start
end = self.mark_start + need_size
self.mark_start += need_size

self.can_use_mem_size -= need_size
self.shared_can_use_token_num.set_value(self.can_use_mem_size)

# Return through a staging buffer to avoid memory races under asynchronous execution
if self._return_start + need_size > self._mem_state_return.shape[0]:
self._return_start = 0
ans = self._mem_state_return[self._return_start : self._return_start + need_size]
ans.copy_(self.mem_state[start:end])
self._return_start += need_size
return ans

def free(self, free_index: Union[torch.Tensor, List[int]]):
"""_summary_

Args:
free_index (torch.Tensor): _description_
"""
return self.allocator.alloc(need_size)

end = self.mark_start
start = self.mark_start - len(free_index)
assert start >= 0, f"error free state start: {self.mark_start} free len {len(free_index)}"

if isinstance(free_index, list):
self.mem_state.numpy()[start:end] = free_index
else:
# The GPU-to-CPU copy is a stream-blocking operation
self.mem_state[start:end] = free_index

self.mark_start -= len(free_index)

self.can_use_mem_size += len(free_index)
self.shared_can_use_token_num.set_value(self.can_use_mem_size)

if self.can_use_mem_size == len(self.mem_state):
logger.debug(f"freed all gpu mem size {self.can_use_mem_size}")
return
def free(self, free_index: Union[torch.Tensor, List[int]]) -> None:
self.allocator.free(free_index)

def free_all(self):
self.can_use_mem_size = len(self.mem_state)
self.shared_can_use_token_num.set_value(self.can_use_mem_size)
self.mem_state.numpy()[:] = list(range(0, len(self.mem_state)))
self.mark_start = 0
self.mark_end = len(self.mem_state)
self.allocator.free_all()

def resize_mem(self, new_size):
Comment on lines 323 to 332
Contributor

medium

For consistency with the free method, please add type hints for the parameters and return values of alloc, free_all, and resize_mem.

Suggested change
-    def alloc(self, need_size) -> torch.Tensor:
+    def alloc(self, need_size: int) -> torch.Tensor:
         return self.allocator.alloc(need_size)

     def free(self, free_index: Union[torch.Tensor, List[int]]) -> None:
         self.allocator.free(free_index)

-    def free_all(self):
+    def free_all(self) -> None:
         self.allocator.free_all()

-    def resize_mem(self, new_size):
+    def resize_mem(self, new_size: int) -> None:

"""
@@ -401,13 +340,8 @@ def resize_mem(self, new_size):
layer_num = self.layer_num

self.size = new_size
self.mem_state = torch.arange(
0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
)
self.mark_start = 0
self.mark_end = self.size
self.can_use_mem_size = self.size
self.shared_can_use_token_num.set_value(self.can_use_mem_size)
self.allocator.resize(new_size)
Comment on lines 342 to +343
Contributor

medium

The HOLD_TOKEN_MEMINDEX attribute is initialized to self.size in __init__. When resizing the memory, this attribute should also be updated to reflect the new size, ensuring consistency for any components relying on this marker.

Suggested change
         self.size = new_size
         self.allocator.resize(new_size)
+        self.HOLD_TOKEN_MEMINDEX = self.size

self.HOLD_TOKEN_MEMINDEX = self.size
self._free_buffers()
self._init_buffers(size, dtype, head_num, head_dim, layer_num)
return
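
The new KvCacheAllocator itself (lightllm/common/kv_cache_mem_manager/allocator.py) is not part of the hunks shown above, so the sketch below is only inferred from the call sites in mem_manager.py and the radix caches (alloc, free, free_all, resize, can_use_mem_size, size). Treat every detail, including the SharedInt bookkeeping the old code did inline, as an assumption rather than the actual implementation:

```python
# Illustrative sketch of the allocator interface assumed by this PR's call sites.
# Not the real lightllm implementation; the actual class may differ substantially.
from typing import List, Union

import torch


class KvCacheAllocatorSketch:
    def __init__(self, size: int):
        self.size = size
        self.can_use_mem_size = size
        # Free token indices kept on pinned CPU memory, mirroring the removed mem_state.
        self._free_state = torch.arange(0, size, dtype=torch.int32, device="cpu", pin_memory=True)
        self._mark_start = 0  # everything before this offset is currently handed out

    def alloc(self, need_size: int) -> torch.Tensor:
        assert need_size <= self.can_use_mem_size, "not enough free kv cache tokens"
        start = self._mark_start
        self._mark_start += need_size
        self.can_use_mem_size -= need_size
        # Clone so callers cannot race on the underlying free-list buffer.
        return self._free_state[start : start + need_size].clone()

    def free(self, free_index: Union[torch.Tensor, List[int]]) -> None:
        num = len(free_index)
        start = self._mark_start - num
        if isinstance(free_index, list):
            self._free_state.numpy()[start : self._mark_start] = free_index
        else:
            # A GPU-to-CPU copy here is a stream-blocking operation.
            self._free_state[start : self._mark_start] = free_index
        self._mark_start = start
        self.can_use_mem_size += num

    def free_all(self) -> None:
        self._free_state.numpy()[:] = list(range(self.size))
        self._mark_start = 0
        self.can_use_mem_size = self.size

    def resize(self, new_size: int) -> None:
        self.size = new_size
        self.can_use_mem_size = new_size
        self._free_state = torch.arange(0, new_size, dtype=torch.int32, device="cpu", pin_memory=True)
        self._mark_start = 0
```

Presumably the real allocator also publishes can_use_mem_size through the SharedInt counter the router reads, since that bookkeeping disappeared from MemoryManager in this diff.
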
2 changes: 1 addition & 1 deletion lightllm/server/api_cli.py
@@ -145,7 +145,7 @@ def make_argument_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--mem_fraction",
type=float,
default=0.9,
default=0.8,
help="""Memory usage ratio, default is 0.9, you can specify a smaller value if OOM occurs at runtime.
If max_total_token_num is not specified, it will be calculated automatically based on this value.""",
)
13 changes: 9 additions & 4 deletions lightllm/server/api_openai.py
@@ -153,8 +153,13 @@ def _get_history_tool_calls_cnt(request: ChatCompletionRequest) -> int:
return idx


def _get_reasoning_from_request(request: ChatCompletionRequest) -> bool:
"""Judge whether the request needs reasoning"""
def _is_force_thinking_mode(request: ChatCompletionRequest) -> bool:
"""Whether this request uses forced thinking / reasoning (parser + template)."""
from .build_prompt import tokenizer_supports_force_thinking

if not tokenizer_supports_force_thinking():
return False

reasoning_parser = get_env_start_args().reasoning_parser
if not reasoning_parser:
return False
@@ -175,7 +180,7 @@ def _process_reasoning_stream(
) -> tuple[Optional[str], str]:
"""Process reasoning content in streaming response"""
if index not in reasoning_parser_dict:
request_enable_reasoning = _get_reasoning_from_request(request)
request_enable_reasoning = _is_force_thinking_mode(request)
reasoning_parser_dict[index] = ReasoningParser(
get_env_start_args().reasoning_parser,
request.stream_reasoning,
@@ -376,7 +381,7 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req
reasoning_text = None
reasoning_parser = get_env_start_args().reasoning_parser
if reasoning_parser:
request_enable_reasoning = _get_reasoning_from_request(request)
request_enable_reasoning = _is_force_thinking_mode(request)
try:
parser = ReasoningParser(
model_type=reasoning_parser,
9 changes: 4 additions & 5 deletions lightllm/server/api_start.py
@@ -127,15 +127,14 @@ def normal_or_p_d_start(args):

# performance_mode 参数处理
if args.performance_mode == "personal":
args.running_max_req_size = 3
args.running_max_req_size = 6
args.batch_max_tokens = 2048
args.chunked_prefill_size = 1024
if args.mem_fraction > 0.82:
args.mem_fraction = 0.82
args.graph_max_batch_size = 32
args.embed_cache_storage_size = 0.8
args.graph_max_batch_size = 6
logger.info(
f"performance_mode is personal, set running_max_req_size to 3,"
f"batch_max_tokens to 2048, chunked_prefill_size to 1024, mem_fraction to 0.82,"
f"batch_max_tokens to 2048, chunked_prefill_size to 1024,"
f"graph_max_batch_size to 32"
)

40 changes: 40 additions & 0 deletions lightllm/server/build_prompt.py
@@ -2,6 +2,7 @@
import json
from lightllm.server.tokenizer import get_tokenizer
from lightllm.utils.log_utils import init_logger
from functools import lru_cache

logger = init_logger(__name__)

@@ -45,6 +46,32 @@ def init_tokenizer(args):
return


@lru_cache(maxsize=1)
def tokenizer_supports_force_thinking() -> bool:
"""Whether this tokenizer supports thinking / reasoning."""

assert tokenizer is not None

try:
ans = "thinking" in tokenizer.chat_template or "enable_thinking" in tokenizer.chat_template
logger.debug(f"chat_template: {tokenizer.chat_template}")
logger.info(f"tokenizer_supports_force_thinking : {ans}")
return ans
except:
pass

try:
ans = "thinking" in tokenizer.tokenizer.chat_template or "enable_thinking" in tokenizer.tokenizer.chat_template
logger.debug(f"tokenizer.tokenizer.chat_template: {tokenizer.tokenizer.chat_template}")
logger.info(f"tokenizer_supports_force_thinking : {ans}")
return ans
except:
pass

logger.info("tokenizer_supports_force_thinking : False")
return False


def _normalize_tool_call_arguments(messages: list) -> None:
# Convert tool_calls function.arguments from JSON string to dict for Jinja template compatibility
# Qwen35's chat template expects arguments to be a dict (uses |items filter)
@@ -94,6 +121,19 @@ async def build_prompt(request, tools) -> str:
if request.chat_template_kwargs:
kwargs.update(request.chat_template_kwargs)

# Fix cases where some parser types enable thinking by default while the tokenizer does not know that
# thinking is on, which makes the constructed reasoning parser and the tokenizer behave inconsistently.
from .api_openai import _is_force_thinking_mode

thinking = _is_force_thinking_mode(request)

kwargs["thinking"] = thinking
kwargs["enable_thinking"] = thinking

# TODO: there should be three thinking modes: forced thinking, forced no-thinking, and an adaptive
# mode where the model decides by itself. The current code only implements forced thinking and forced
# no-thinking. Later, the supported thinking modes should be detected from the tokenizer per model and
# set accordingly for more complete handling.

try:
input_str = tokenizer.apply_chat_template(**kwargs, tokenize=False, add_generation_prompt=True, tools=tools)
except BaseException as e:
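
For context on why build_prompt now passes both thinking and enable_thinking, here is a minimal sketch of how such kwargs typically reach a Hugging Face chat template. The model id and the exact kwarg the template honors are assumptions for illustration, not something this PR pins down:

```python
# Illustrative sketch, not part of this PR: chat_template kwargs such as
# enable_thinking are simply forwarded to the Jinja template, which may or may
# not use them depending on the model.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")  # hypothetical choice of model

messages = [{"role": "user", "content": "What is 2 + 2?"}]

prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    enable_thinking=True,  # ignored by templates that never reference it
)
print(prompt)
```

This is also why tokenizer_supports_force_thinking() checks the template text for "thinking" / "enable_thinking" before the server decides to force the flag on.
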
2 changes: 1 addition & 1 deletion lightllm/server/core/objs/start_args_type.py
@@ -32,7 +32,7 @@ class StartArgs:
tokenizer_mode: str = field(default="slow")
load_way: str = field(default="HF")
max_total_token_num: Optional[int] = field(default=None)
mem_fraction: float = field(default=0.9)
mem_fraction: float = field(default=0.8)
batch_max_tokens: Optional[int] = field(default=None)
eos_id: List[int] = field(default_factory=list)
tool_call_parser: Optional[str] = field(
2 changes: 1 addition & 1 deletion lightllm/server/httpserver/manager.py
@@ -529,7 +529,7 @@ async def _encode(

if self.args.detail_log:
logger.debug(
f"req_id: {sampling_params.group_request_id} prompt: {prompt},\n"
f"req_id: {sampling_params.group_request_id} prompt: {prompt}\n"
f"samplingparmas: {sampling_params.to_dict()}\n"
f"token_ids: {prompt_ids}"
)
@@ -568,8 +568,8 @@ def _print_helper(self, node: LinearAttPagedTreeNode, indent):

def free_radix_cache_to_get_enough_token(self, need_token_num):
assert self.mem_manager is not None
if need_token_num > self.mem_manager.can_use_mem_size:
need_evict_token_num = need_token_num - self.mem_manager.can_use_mem_size
if need_token_num > self.mem_manager.allocator.can_use_mem_size:
need_evict_token_num = need_token_num - self.mem_manager.allocator.can_use_mem_size
release_mems = []
small_page_buffer_ids = []

10 changes: 2 additions & 8 deletions lightllm/server/router/dynamic_prompt/radix_cache.py
@@ -401,12 +401,6 @@ def merge_unreferenced_nodes(self):
if merged_node:
worklist.append(merged_node)

def assert_leafs_is_right(self):
for node in self.evict_tree_set:
if node.is_leaf() and node.ref_counter == 0:
a = node.token_mem_index_value.cuda()
assert (self.mem_manager.mem_state[a] == 1).sum().item() == len(a)

def clear_tree_nodes(self):
"""
This function is only called in tests.
@@ -497,8 +491,8 @@ def _print_helper(self, node: TreeNode, indent):

def free_radix_cache_to_get_enough_token(self, need_token_num):
assert self.mem_manager is not None
if need_token_num > self.mem_manager.can_use_mem_size:
need_evict_token_num = need_token_num - self.mem_manager.can_use_mem_size
if need_token_num > self.mem_manager.allocator.can_use_mem_size:
need_evict_token_num = need_token_num - self.mem_manager.allocator.can_use_mem_size
release_mems = []

def release_mem(mem_index):
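
Since only the attribute lookup changed in these hunks (can_use_mem_size now lives on mem_manager.allocator), here is a hedged sketch of the eviction contract the call sites rely on; the evict helper and its callback signature are assumptions, not code from this diff:

```python
# Illustrative sketch of free_radix_cache_to_get_enough_token's contract.
# The evict() helper shown here is assumed; the real radix cache API may differ.
from typing import List

import torch


def free_radix_cache_to_get_enough_token_sketch(radix_cache, mem_manager, need_token_num: int) -> None:
    """Evict unreferenced radix-tree nodes until the allocator can satisfy need_token_num."""
    if need_token_num > mem_manager.allocator.can_use_mem_size:
        need_evict_token_num = need_token_num - mem_manager.allocator.can_use_mem_size
        release_mems: List[torch.Tensor] = []

        # Assumed helper: walk unreferenced leaves and hand their token indices to the callback.
        radix_cache.evict(need_evict_token_num, lambda mem_index: release_mems.append(mem_index))

        if release_mems:
            mem_manager.free(torch.cat(release_mems))
```
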
6 changes: 3 additions & 3 deletions lightllm/server/router/model_infer/infer_batch.py
@@ -280,8 +280,8 @@ def _filter(self, finished_request_ids: List[int]):
f"free a batch state:\n"
f"radix refed token num {self.radix_cache.get_refed_tokens_num()}\n"
f"radix hold token num {self.radix_cache.get_tree_total_tokens_num()}\n"
f"mem manager can alloc token num {self.req_manager.mem_manager.can_use_mem_size}\n"
f"mem manager total size {self.req_manager.mem_manager.size}"
f"mem manager can alloc token num {self.req_manager.mem_manager.allocator.can_use_mem_size}\n"
f"mem manager total size {self.req_manager.mem_manager.allocator.size}\n"
Contributor

medium

The addition of a trailing newline \n at the end of the log message string will result in an extra empty line in the output. It is better to keep the log message concise without the trailing newline, consistent with the original implementation.

Suggested change
f"mem manager total size {self.req_manager.mem_manager.allocator.size}\n"
f"mem manager total size {self.req_manager.mem_manager.allocator.size}"

)

return
@@ -348,7 +348,7 @@ def get_can_alloc_token_num(self):
radix_cache_unref_token_num = (
self.radix_cache.get_tree_total_tokens_num() - self.radix_cache.get_refed_tokens_num()
)
return self.req_manager.mem_manager.can_use_mem_size + radix_cache_unref_token_num
return self.req_manager.mem_manager.allocator.can_use_mem_size + radix_cache_unref_token_num

def copy_linear_att_state_to_cache_buffer(self, b_req_idx: torch.Tensor, reqs: List["InferReq"]):
"""
@@ -80,8 +80,8 @@ def _alloc_to_frozen_some_tokens(self, move_task: KVMoveTask):
logger.debug(
f"radix refed token num {self.backend.radix_cache.get_refed_tokens_num()}\n"
f"radix hold token num {self.backend.radix_cache.get_tree_total_tokens_num()}\n"
f"mem manager can alloc token num {self.backend.model.mem_manager.can_use_mem_size}\n"
f"mem manager total size {self.backend.model.mem_manager.size}"
f"mem manager can alloc token num {self.backend.model.mem_manager.allocator.can_use_mem_size}\n"
f"mem manager total size {self.backend.model.mem_manager.allocator.size}\n"
Contributor

medium

Similar to the change in infer_batch.py, the trailing newline \n here adds an unnecessary empty line to the debug log output. Please remove it for cleaner log formatting.

Suggested change
f"mem manager total size {self.backend.model.mem_manager.allocator.size}\n"
f"mem manager total size {self.backend.model.mem_manager.allocator.size}"

f"frozened token num {frozen_token_num}\n"
f"estimated peak token num {estimated_peak_token_num}\n"
)