30 changes: 18 additions & 12 deletions lightllm/server/api_openai.py
@@ -614,6 +614,7 @@ async def _collect_generation_results(
"text": request_output,
"logprob": metadata.get("logprob", None),
"id": metadata.get("id", None),
"top_logprobs": metadata.get("top_logprobs", None),
}
token_infos.append(token_info)

@@ -693,16 +694,30 @@ def _build_logprobs_data(result: Dict, request: CompletionRequest, tokenizer) ->
all_tokens = []
all_token_logprobs = []
all_text_offsets = []
all_top_logprobs = []
offset = 0

def add_tokens_to_logprobs(token_ids=None, token_infos=None, logprob_map=None):
nonlocal offset

def add_single_token(token_text: str, logprob: float):
def add_single_token(token_text: str, logprob: float, top_logprobs: List[Dict[int, float]] = None):
nonlocal offset
all_tokens.append(token_text)
all_token_logprobs.append(logprob)
all_text_offsets.append(offset)
if top_logprobs is not None:
formatted_top_logprobs = {}
for item in top_logprobs:
for t_id, t_prob in item.items():
t_text = tokenizer.decode([t_id], skip_special_tokens=False)
formatted_top_logprobs[t_text] = t_prob
Comment on lines +709 to +713
Contributor
medium

The current implementation decodes token IDs one by one inside a loop. This can be inefficient, especially for a large number of top logprobs. You can improve performance by collecting all token IDs and using tokenizer.batch_decode to process them in a single call.

Suggested change
-    formatted_top_logprobs = {}
-    for item in top_logprobs:
-        for t_id, t_prob in item.items():
-            t_text = tokenizer.decode([t_id], skip_special_tokens=False)
-            formatted_top_logprobs[t_text] = t_prob
+    token_ids = [next(iter(item)) for item in top_logprobs]
+    log_probs = [next(iter(item.values())) for item in top_logprobs]
+    token_texts = tokenizer.batch_decode(
+        [[tid] for tid in token_ids], skip_special_tokens=False
+    )
+    formatted_top_logprobs = dict(zip(token_texts, log_probs))
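Note that this rewrite assumes every entry in top_logprobs is a single-key {token_id: logprob} dict, which matches how the HTTP server manager builds the list (one dict per top-k candidate in handle_loop); if an entry ever carried multiple keys, only its first pair would survive the rewrite.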

all_top_logprobs.append(formatted_top_logprobs)
else:
if logprob is not None:
all_top_logprobs.append({token_text: logprob})
else:
all_top_logprobs.append(None)

offset += len(token_text)

if token_ids is not None:
@@ -712,7 +727,7 @@ def add_single_token(token_text: str, logprob: float):
add_single_token(token_text, logprob)
elif token_infos is not None:
for token_info in token_infos:
add_single_token(token_info["text"], token_info["logprob"])
add_single_token(token_info["text"], token_info["logprob"], token_info.get("top_logprobs", None))

# Handle prompt tokens in echo mode
if request.echo and result.get("prompt_logprobs") is not None:
@@ -743,18 +758,9 @@ def add_single_token(token_text: str, logprob: float):
if result.get("token_infos"):
add_tokens_to_logprobs(token_infos=result["token_infos"])

top_logprobs_list = []
for i, (token, logprob) in enumerate(zip(all_tokens, all_token_logprobs)):
if logprob is not None:
# TODO: a standard implementation should fetch top-k logprobs from the backend
# The backend does not support this yet, so only the chosen token's logprob is available
top_logprobs_list.append({token: logprob})
else:
top_logprobs_list.append(None)

return {
"tokens": all_tokens,
"token_logprobs": all_token_logprobs,
"top_logprobs": top_logprobs_list,
"top_logprobs": all_top_logprobs,
"text_offset": all_text_offsets,
}
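For context (not part of the diff), a minimal sketch of the payload that _build_logprobs_data now returns once top_logprobs is populated from the backend; the token strings and numbers below are invented for illustration:

```python
# Illustrative only: a possible _build_logprobs_data result for a two-token completion.
logprobs_payload = {
    "tokens": ["Hello", " world"],          # decoded sampled tokens
    "token_logprobs": [-0.12, -0.87],       # logprob of each sampled token
    "top_logprobs": [
        {"Hello": -0.12, "Hi": -2.31},      # decoded top-k candidates for step 0
        {" world": -0.87, " there": -1.95}, # decoded top-k candidates for step 1
    ],
    "text_offset": [0, 5],                  # cumulative character offset of each token
}
```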
25 changes: 25 additions & 0 deletions lightllm/server/core/objs/req.py
@@ -16,6 +16,8 @@

logger = init_logger(__name__)

MAX_TOP_K_LOGPROBS = 20


class FinishStatus(ctypes.Structure):
_pack_ = 4
@@ -170,6 +172,7 @@ def init(
self.input_len = len(prompt_ids)
self.alloc_shm_numpy_len = self.input_len + self.sample_params.max_new_tokens + 1024 # + 1024 for safe
self.create_logprobs_shm_array()
self.create_top_logprobs_shm_array()
self.create_prompt_ids_shm_array()
self.chunked_prefill_size = chunked_prefill_size
self.shm_prompt_ids.arr[0 : len(prompt_ids)] = prompt_ids
@@ -218,13 +221,35 @@ def create_logprobs_shm_array(self):
self.shm_logprobs.create_shm()
return

def create_top_logprobs_shm_array(self):
service_uni_name = get_unique_server_name()
name_ids = f"{service_uni_name}_shm_top_logprobs_ids_{self.index_in_shm_mem}"
self.shm_top_logprobs_ids = ShmArray(name_ids, (self.alloc_shm_numpy_len, MAX_TOP_K_LOGPROBS), dtype=np.int32)
self.shm_top_logprobs_ids.create_shm()

name_val = f"{service_uni_name}_shm_top_logprobs_val_{self.index_in_shm_mem}"
self.shm_top_logprobs_val = ShmArray(name_val, (self.alloc_shm_numpy_len, MAX_TOP_K_LOGPROBS), dtype=np.float32)
self.shm_top_logprobs_val.create_shm()
return

def link_logprobs_shm_array(self):
service_uni_name = get_unique_server_name()
name = f"{service_uni_name}_shm_logprobs_{self.index_in_shm_mem}"
self.shm_logprobs = ShmArray(name, (self.alloc_shm_numpy_len,), dtype=np.float32)
self.shm_logprobs.link_shm()
return

def link_top_logprobs_shm_array(self):
service_uni_name = get_unique_server_name()
name_ids = f"{service_uni_name}_shm_top_logprobs_ids_{self.index_in_shm_mem}"
self.shm_top_logprobs_ids = ShmArray(name_ids, (self.alloc_shm_numpy_len, MAX_TOP_K_LOGPROBS), dtype=np.int32)
self.shm_top_logprobs_ids.link_shm()

name_val = f"{service_uni_name}_shm_top_logprobs_val_{self.index_in_shm_mem}"
self.shm_top_logprobs_val = ShmArray(name_val, (self.alloc_shm_numpy_len, MAX_TOP_K_LOGPROBS), dtype=np.float32)
self.shm_top_logprobs_val.link_shm()
return
Comment on lines +224 to +251
Contributor
medium

The methods create_top_logprobs_shm_array and link_top_logprobs_shm_array contain a lot of duplicated code. The only difference is calling create_shm() vs link_shm(). This could be refactored into a helper method to reduce redundancy and improve maintainability. For example, a private helper method could handle the initialization of the ShmArray objects, while the public methods would just call the appropriate create_shm() or link_shm() on them.
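One possible shape for that refactor, as a minimal sketch meant to slot into this class; it relies only on the ShmArray calls already used above (create_shm() / link_shm()), and the helper name _init_top_logprobs_shm_arrays is hypothetical:

```python
def _init_top_logprobs_shm_arrays(self):
    # Build the two ShmArray handles shared by the create/link paths.
    service_uni_name = get_unique_server_name()
    shape = (self.alloc_shm_numpy_len, MAX_TOP_K_LOGPROBS)
    name_ids = f"{service_uni_name}_shm_top_logprobs_ids_{self.index_in_shm_mem}"
    name_val = f"{service_uni_name}_shm_top_logprobs_val_{self.index_in_shm_mem}"
    self.shm_top_logprobs_ids = ShmArray(name_ids, shape, dtype=np.int32)
    self.shm_top_logprobs_val = ShmArray(name_val, shape, dtype=np.float32)
    return self.shm_top_logprobs_ids, self.shm_top_logprobs_val

def create_top_logprobs_shm_array(self):
    for shm_arr in self._init_top_logprobs_shm_arrays():
        shm_arr.create_shm()

def link_top_logprobs_shm_array(self):
    for shm_arr in self._init_top_logprobs_shm_arrays():
        shm_arr.link_shm()
```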


def get_prompt_ids(self):
return self.shm_prompt_ids.arr[: self.input_len].tolist()

12 changes: 12 additions & 0 deletions lightllm/server/httpserver/manager.py
@@ -25,6 +25,7 @@
from lightllm.server.core.objs import Req, FinishStatus, StartArgs
from lightllm.server.core.objs import SamplingParams
from lightllm.server.core.objs.out_token_circlequeue import LIGHTLLM_OUT_TOKEN_QUEUE_SIZE
from lightllm.server.core.objs.req import MAX_TOP_K_LOGPROBS
from lightllm.server.core.objs.io_objs import GroupReqObjs
from lightllm.server.core.objs.shm_req_manager import ShmReqManager
from lightllm.server.core.objs.atomic_array_lock import AtomicShmArrayLock, AsyncLock, AtomicLockItem
@@ -730,6 +731,17 @@ async def handle_loop(self):
"cpu_prompt_cache_len": req.cpu_prompt_cache_len,
"mtp_accepted_token_num": req.mtp_accepted_token_num,
}

top_k_ids = req.shm_top_logprobs_ids.arr[src_index]
top_k_vals = req.shm_top_logprobs_val.arr[src_index]
top_logprobs = []
for i in range(MAX_TOP_K_LOGPROBS):
if top_k_vals[i] == -float("inf"):
break
top_logprobs.append({int(top_k_ids[i]): float(top_k_vals[i])})
if top_logprobs:
metadata["top_logprobs"] = top_logprobs

if self.args.return_all_prompt_logprobs:
metadata.update(req.get_all_prompt_metadata())
if self.args.use_reward_model:
24 changes: 22 additions & 2 deletions lightllm/server/router/model_infer/infer_batch.py
@@ -10,6 +10,7 @@
from lightllm.common.req_manager import ReqManager
from lightllm.utils.infer_utils import mark_start, mark_end
from lightllm.server.core.objs import Req, SamplingParams, FinishStatus, ShmReqManager
from lightllm.server.core.objs.req import MAX_TOP_K_LOGPROBS
from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache, TreeNode
from lightllm.utils.log_utils import init_logger
from lightllm.server.req_id_generator import convert_sub_id_to_group_id
@@ -361,6 +362,7 @@ def _init_all_state(self):
self.shm_req = g_infer_context.shm_req_manager.get_req_obj_by_index(self.shm_index)
self.shm_req.link_prompt_ids_shm_array()
self.shm_req.link_logprobs_shm_array()
self.shm_req.link_top_logprobs_shm_array()
self.sampling_param: InferSamplingParams = InferSamplingParams(self.shm_req, self.vocab_size)

# In nixl PD-disaggregation mode, update the start position from which the prefill node begins transferring
@@ -453,10 +455,26 @@ def get_chuncked_input_token_len(self):
chunked_end = min(self.get_cur_total_len(), chunked_start + self.shm_req.chunked_prefill_size)
return chunked_end

def set_next_gen_token_id(self, next_token_id: int, logprob: float, output_len: int):
def set_next_gen_token_id(
self,
next_token_id: int,
logprob: float,
output_len: int,
top_k_ids: List[int] = None,
top_k_logprobs: List[float] = None,
):
index = self.shm_req.input_len + output_len
self.shm_req.shm_prompt_ids.arr[index - 1] = next_token_id
self.shm_req.shm_logprobs.arr[index - 1] = logprob

if top_k_ids is not None and top_k_logprobs is not None:
k = min(len(top_k_ids), MAX_TOP_K_LOGPROBS)
self.shm_req.shm_top_logprobs_ids.arr[index - 1, :k] = top_k_ids[:k]
self.shm_req.shm_top_logprobs_val.arr[index - 1, :k] = top_k_logprobs[:k]
# Clear any remaining slots: ids to 0, values to -inf as the sentinel
if k < MAX_TOP_K_LOGPROBS:
self.shm_req.shm_top_logprobs_ids.arr[index - 1, k:] = 0
self.shm_req.shm_top_logprobs_val.arr[index - 1, k:] = -float("inf")
return

def update_mtp_accepted_token_num(self, accept_token_num: int):
Expand Down Expand Up @@ -528,6 +546,8 @@ def handle(
extra_post_req_handle_func: Optional[Callable[[InferReq, int, float], None]],
is_master_in_dp: bool,
nixl_prefill_chuncked_handle_func: Optional[Callable[[InferReq, int, float, int], None]] = None,
top_k_ids: List[int] = None,
top_k_logprobs: List[float] = None,
):
# nixl_prefill_chuncked_handle_func mainly handles, in nixl prefill mode,
# the corresponding PD chunked transfer after chunked prefill.
Expand All @@ -540,7 +560,7 @@ def handle(
req_obj = self.req_obj
shm_req = req_obj.shm_req
finish_status = req_obj.finish_status
req_obj.set_next_gen_token_id(next_token_id, next_token_logprob, self.output_len)
req_obj.set_next_gen_token_id(next_token_id, next_token_logprob, self.output_len, top_k_ids, top_k_logprobs)

# The point of this early check is that,
# in mtp mode, the same req object may be processed multiple times,
48 changes: 39 additions & 9 deletions lightllm/server/router/model_infer/mode_backend/base_backend.py
@@ -288,7 +288,13 @@ def init_mtp_draft_model(self, main_kvargs: dict):
self.logger.info(f"loaded mtp model class {self.draft_models[i].__class__}")
return

def _async_copy_next_token_infos_to_pin_mem(self, next_token_ids: torch.Tensor, next_token_logprobs: torch.Tensor):
def _async_copy_next_token_infos_to_pin_mem(
self,
next_token_ids: torch.Tensor,
next_token_logprobs: torch.Tensor,
top_k_ids: torch.Tensor = None,
top_k_logprobs: torch.Tensor = None,
):
"""
This function copies the next token ids and logprobs into pinned memory,
so that the post_handle function can reliably read the correct outputs.
@@ -301,7 +307,20 @@ def _async_copy_next_token_infos_to_pin_mem(self, next_token_ids: torch.Tensor,
key="next_token_logprobs",
gpu_tensor=next_token_logprobs,
)
return next_token_ids_cpu, next_token_logprobs_cpu

top_k_ids_cpu = None
top_k_logprobs_cpu = None
if top_k_ids is not None:
top_k_ids_cpu = g_pin_mem_manager.async_copy_from_gpu_tensor(
key="top_k_ids",
gpu_tensor=top_k_ids,
)
top_k_logprobs_cpu = g_pin_mem_manager.async_copy_from_gpu_tensor(
key="top_k_logprobs",
gpu_tensor=top_k_logprobs,
)

return next_token_ids_cpu, next_token_logprobs_cpu, top_k_ids_cpu, top_k_logprobs_cpu

def _try_read_new_reqs(self):
if self.is_multinode_tp:
@@ -646,19 +665,27 @@ def _post_handle(
run_reqs_update_packs: List[InferReqUpdatePack],
extra_post_req_handle_func: Optional[Callable[[InferReq, int, float], None]] = None,
nixl_prefill_chuncked_handle_func: Optional[Callable[[InferReq, int, float, int], None]] = None,
top_k_ids: List[List[int]] = None,
top_k_logprobs: List[List[float]] = None,
):
"""
extra_post_req_handle_func provides extra post-processing once a request's output token is decided; it is mainly used by
modes such as constrained output to update the request's internal state machine and add extra stop conditions.
"""
for req_obj, next_token_id, next_token_logprob, pack in zip(
run_reqs, next_token_ids, next_token_logprobs, run_reqs_update_packs
if top_k_ids is None:
top_k_ids = [None] * len(run_reqs)
top_k_logprobs = [None] * len(run_reqs)

for req_obj, next_token_id, next_token_logprob, cur_top_k_ids, cur_top_k_logprobs, pack in zip(
run_reqs, next_token_ids, next_token_logprobs, top_k_ids, top_k_logprobs, run_reqs_update_packs
):
req_obj: InferReq = req_obj
pack: InferReqUpdatePack = pack
pack.handle(
next_token_id=next_token_id,
next_token_logprob=next_token_logprob,
top_k_ids=cur_top_k_ids,
top_k_logprobs=cur_top_k_logprobs,
eos_ids=self.eos_id,
extra_post_req_handle_func=extra_post_req_handle_func,
is_master_in_dp=self.is_master_in_dp,
@@ -724,7 +751,7 @@ def _sample_and_scatter_token(
assert len(run_reqs) == logits.shape[0]
mask_func(run_reqs, logits)

next_token_ids, next_token_logprobs = sample(logits, run_reqs, self.eos_id)
next_token_ids, next_token_logprobs, top_k_ids, top_k_logprobs = sample(logits, run_reqs, self.eos_id)
b_has_out = None
if is_prefill:
b_has_out = g_pin_mem_manager.gen_from_list(
@@ -743,10 +770,13 @@
next_token_ids=next_token_ids,
mask=b_has_out,
)
next_token_ids_cpu, next_token_logprobs_cpu = self._async_copy_next_token_infos_to_pin_mem(
next_token_ids, next_token_logprobs
)
return next_token_ids, next_token_ids_cpu, next_token_logprobs_cpu
(
next_token_ids_cpu,
next_token_logprobs_cpu,
top_k_ids_cpu,
top_k_logprobs_cpu,
) = self._async_copy_next_token_infos_to_pin_mem(next_token_ids, next_token_logprobs, top_k_ids, top_k_logprobs)
return next_token_ids, next_token_ids_cpu, next_token_logprobs_cpu, top_k_ids_cpu, top_k_logprobs_cpu

def _dp_all_gather_prefill_and_decode_req_num(
self, prefill_reqs: List[InferReq], decode_reqs: List[InferReq]