@@ -138,10 +138,10 @@ def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_t
topk_ids=topk_ids,
inplace=True,
use_fp8_w8a8=use_fp8_w8a8,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_bias=self.w1_bias,
w2_bias=self.w2_bias / self.tp_world_size_,
w1_scale=w1_scale,
w2_scale=w2_scale,
layout="interleaved",
alpha=self.alpha,
limit=self.limit,
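Note on the bias plumbing reordered in the call above: the call site passes `w2_bias=self.w2_bias / self.tp_world_size_`, presumably so that after the tensor-parallel all-reduce of the partial down-projection outputs the bias is added exactly once. Below is a minimal, unfused single-expert reference of that math — a sketch only, assuming a plain SwiGLU gate/up split and omitting the `interleaved` layout and the `alpha`/`limit` clamping that the real kernel takes.

```python
from typing import Optional
import torch
import torch.nn.functional as F

def expert_ffn_reference(
    x: torch.Tensor,                      # [num_tokens, hidden]
    w1: torch.Tensor,                     # [2 * inter, hidden] fused gate/up projection
    w2: torch.Tensor,                     # [hidden, inter] down projection (per-rank slice under TP)
    w1_bias: Optional[torch.Tensor] = None,
    w2_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    h = x @ w1.t()
    if w1_bias is not None:
        h = h + w1_bias
    gate, up = h.chunk(2, dim=-1)         # plain (non-interleaved) split, for illustration only
    h = F.silu(gate) * up
    out = h @ w2.t()
    if w2_bias is not None:
        # Under tensor parallelism each rank produces a partial `out` that is
        # all-reduced afterwards; passing w2_bias / tp_world_size (as the call
        # site above does) makes the summed bias contribution equal one w2_bias.
        out = out + w2_bias
    return out
```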
38 changes: 19 additions & 19 deletions lightllm/common/fused_moe/grouped_fused_moe.py
@@ -764,13 +764,13 @@ def fused_experts_impl(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_bias: Optional[torch.Tensor],
w2_bias: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
@@ -890,13 +890,13 @@ def inplace_fused_experts_impl(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
# optional bias for w1 and w2
w1_bias: Optional[torch.Tensor],
w2_bias: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
# optional bias for w1 and w2
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
@@ -909,13 +909,13 @@ def inplace_fused_experts_impl(
hidden_states,
w1,
w2,
w1_bias,
w2_bias,
topk_weights,
topk_ids,
True,
use_fp8_w8a8,
use_int8_w8a16,
w1_bias,
w2_bias,
w1_scale,
w2_scale,
a1_scale,
@@ -930,13 +930,13 @@ def inplace_fused_experts_impl_fake(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
# optional bias for w1 and w2
w1_bias: Optional[torch.Tensor],
w2_bias: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
# optional bias for w1 and w2
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
@@ -960,13 +960,13 @@ def outplace_fused_experts_impl(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
# optional bias for w1 and w2
w1_bias: Optional[torch.Tensor],
w2_bias: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
# optional bias for w1 and w2
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
@@ -979,13 +979,13 @@ def outplace_fused_experts_impl(
hidden_states,
w1,
w2,
w1_bias,
w2_bias,
topk_weights,
topk_ids,
False,
use_fp8_w8a8,
use_int8_w8a16,
w1_bias,
w2_bias,
w1_scale,
w2_scale,
a1_scale,
@@ -1051,12 +1051,12 @@ def fused_experts(
hidden_states,
w1,
w2,
w1_bias,
w2_bias,
topk_weights,
topk_ids,
use_fp8_w8a8,
use_int8_w8a16,
w1_bias,
w2_bias,
w1_scale,
w2_scale,
a1_scale,
@@ -1071,12 +1071,12 @@ def fused_experts(
hidden_states,
w1,
w2,
w1_bias,
w2_bias,
topk_weights,
topk_ids,
use_fp8_w8a8,
use_int8_w8a16,
w1_bias,
w2_bias,
w1_scale,
w2_scale,
a1_scale,
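The parameter moves above have to be mirrored at every forwarding call because `fused_experts` and the inplace/outplace wrappers pass their arguments through positionally. A rough sketch of that dispatch shape is below — illustrative only; the real functions take many more parameters and are registered as torch custom ops, and the kernel body is elided.

```python
from typing import Optional
import torch

def _fused_experts_impl_sketch(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    inplace: bool = False,
    use_fp8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    w1_bias: Optional[torch.Tensor] = None,
    w2_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    out = hidden_states if inplace else torch.empty_like(hidden_states)
    # ... grouped GEMMs over the selected experts would go here ...
    return out

def fused_experts_sketch(hidden_states, w1, w2, topk_weights, topk_ids,
                         inplace=False, use_fp8_w8a8=False, use_int8_w8a16=False,
                         w1_bias=None, w2_bias=None):
    # Positional forwarding: the argument order here must match the impl's
    # signature exactly, which is why the diff reorders the signatures and
    # every forwarding call together.
    return _fused_experts_impl_sketch(
        hidden_states, w1, w2, topk_weights, topk_ids,
        inplace, use_fp8_w8a8, use_int8_w8a16, w1_bias, w2_bias,
    )
```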
28 changes: 27 additions & 1 deletion lightllm/server/api_models.py
@@ -101,6 +101,7 @@ class ChatCompletionRequest(BaseModel):
tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = Field(
default="auto", examples=["none"]
) # noqa
parallel_tool_calls: Optional[bool] = True

# Additional parameters supported by LightLLM
do_sample: Optional[bool] = False
@@ -122,11 +123,36 @@ class FunctionResponse(BaseModel):
class ToolCall(BaseModel):
"""Tool call response."""

id: str
id: Optional[str] = None
index: Optional[int] = None
type: Literal["function"] = "function"
function: FunctionResponse


class ChatCompletionMessageGenericParam(BaseModel):
role: Literal["system", "assistant", "tool", "function"]
content: Union[str, List[MessageContent], None] = Field(default=None)
tool_call_id: Optional[str] = None
name: Optional[str] = None
reasoning_content: Optional[str] = None
tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])

@field_validator("role", mode="before")
@classmethod
def _normalize_role(cls, v):
if isinstance(v, str):
v_lower = v.lower()
if v_lower not in {"system", "assistant", "tool", "function"}:
raise ValueError(
"'role' must be one of 'system', 'assistant', 'tool', or 'function' (case-insensitive)."
)
return v_lower
raise ValueError("'role' must be a string")


ChatCompletionMessageParam = Union[ChatCompletionMessageGenericParam, Message]


class UsageInfo(BaseModel):
prompt_tokens: int = 0
completion_tokens: Optional[int] = 0
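A quick standalone check of the case-insensitive role handling added by `ChatCompletionMessageGenericParam._normalize_role`. The model below is a trimmed mirror of the one in the diff (not an import from lightllm), kept only to show that a "before" validator lowercases the role ahead of the `Literal` check.

```python
from typing import Literal, Optional
from pydantic import BaseModel, field_validator

class GenericMessageSketch(BaseModel):
    """Trimmed mirror of ChatCompletionMessageGenericParam's role normalization."""
    role: Literal["system", "assistant", "tool", "function"]
    content: Optional[str] = None
    tool_call_id: Optional[str] = None

    @field_validator("role", mode="before")
    @classmethod
    def _normalize_role(cls, v):
        if isinstance(v, str):
            v_lower = v.lower()
            if v_lower not in {"system", "assistant", "tool", "function"}:
                raise ValueError(
                    "'role' must be one of 'system', 'assistant', 'tool', or 'function' (case-insensitive)."
                )
            return v_lower
        raise ValueError("'role' must be a string")

print(GenericMessageSketch(role="Assistant", content="done").role)         # -> assistant
print(GenericMessageSketch(role="TOOL", tool_call_id="call_abc123").role)  # -> tool
```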
86 changes: 74 additions & 12 deletions lightllm/server/api_openai.py
@@ -9,7 +9,7 @@
import pickle
import uuid

from .function_call_parser import TOOLS_TAG_LIST, FunctionCallParser
from .function_call_parser import TOOLS_TAG_LIST, FunctionCallParser, ToolCallItem
from .build_prompt import build_prompt, init_tokenizer

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@@ -63,6 +63,52 @@ def create_error_response(status_code: HTTPStatus, message: str) -> JSONResponse
return JSONResponse({"message": message}, status_code=status_code.value)


def _process_tool_call_id(
tool_call_parser,
call_item: ToolCallItem,
history_tool_calls_cnt: int,
) -> str:
"""Process for generating a new and unique `tool_call_id`"""
if tool_call_parser != "kimi_k2":
# A simple uuid is sufficient for all models except for Kimi-K2.
tool_call_id = f"call_{uuid.uuid4().hex[:24]}"
return tool_call_id
else:
# Align with Kimi-K2 format: functions.{name}:{index}
# Kimi-K2 allows multiple tool_calls in one message;
# SGLang sets call_item.tool_index to the *local* position inside that message.
# Therefore, the index must be corrected by using
# `history_tool_calls_cnt + call_item.tool_index` to ensure ids are globally unique and properly ordered.
tool_call_id = f"functions.{call_item.name}:{history_tool_calls_cnt+call_item.tool_index}"
logger.debug(
f"Process tool call idx, parser: {tool_call_parser}, \
tool_call_id: {tool_call_id}, \
history_cnt: {history_tool_calls_cnt}"
)
return tool_call_id


def _get_history_tool_calls_cnt(request: ChatCompletionRequest) -> int:
"""Counts the number of tool calls in the request's message history.

NOTE: This is only needed for models whose tool call ids embed an
incrementing history tool call index, such as kimi-k2

Args:
request: The chat completion request object.

Returns:
The total number of tool calls in the history, or 0 if not applicable.
"""
messages = getattr(request, "messages", [])
idx = 0
for msg in messages:
if msg.role == "assistant":
tool_calls = getattr(msg, "tool_calls", None)
idx += len(list(tool_calls)) if tool_calls is not None else 0 # noqa
return idx


async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Request) -> Response:
from .api_http import g_objs

@@ -180,26 +226,31 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req

if tool_choice != "none" and any([i in text for i in TOOLS_TAG_LIST]):
if finish_reason == "stop":
finish_reason = "function_call"
finish_reason = "tool_calls"
try:
# Provide a default value for tool_call_parser
tool_parser = getattr(g_objs.args, "tool_call_parser", None) or "llama3"
parser = FunctionCallParser(tools, tool_parser)
full_normal_text, call_info_list = parser.parse_non_stream(text)
tool_calls = [
ToolCall(
id=str(call_info.tool_index),
function=FunctionResponse(name=call_info.name, arguments=call_info.parameters),
tool_calls = []
history_tool_calls_cnt = _get_history_tool_calls_cnt(request)
for call_info in call_info_list:
tool_id = _process_tool_call_id(tool_parser, call_info, history_tool_calls_cnt)
tool_calls.append(
ToolCall(
id=tool_id,
index=getattr(call_info, "tool_index", None),
function=FunctionResponse(name=call_info.name, arguments=call_info.parameters),
)
)
for call_info in call_info_list
]
except Exception as e:
logger.error(f"Exception: {e}")
return create_error_response(
HTTPStatus.BAD_REQUEST,
"Failed to parse fc related info to json format!",
)

if finish_reason == "tool_calls":
text = ""
chat_message = ChatMessage(role="assistant", content=text, tool_calls=tool_calls)
choice = ChatCompletionResponseChoice(
index=i,
@@ -261,6 +312,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
yield f"data: {chunk.model_dump_json()}\n\n"

# 2) if we found calls, we output them as separate chunk(s)
history_tool_calls_cnt = _get_history_tool_calls_cnt(request)
for call_item in calls:
# transform call_item -> FunctionResponse + ToolCall
if finish_reason == "stop":
@@ -278,17 +330,27 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
remaining_call = expected_call.replace(actual_call, "", 1)
call_item.parameters = remaining_call

if call_item.name:
# First chunk: include ID and function name
tool_call_id = _process_tool_call_id(tool_parser, call_item, history_tool_calls_cnt)
function_name = call_item.name
else:
# Subsequent chunks: null ID and name for argument deltas
tool_call_id = None
function_name = None

tool_call = ToolCall(
id=str(call_item.tool_index),
id=tool_call_id,
index=getattr(call_item, "tool_index", None),
function=FunctionResponse(
name=call_item.name,
name=function_name,
arguments=call_item.parameters,
),
)
choice_data = ChatCompletionStreamResponseChoice(
index=0,
delta=DeltaMessage(role="assistant", tool_calls=[tool_call]),
finish_reason="function_call",
finish_reason="tool_calls",
)
chunk = ChatCompletionStreamResponse(
id=group_request_id,
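To make the id scheme in `_process_tool_call_id` concrete, here is a standalone sketch of the two formats it produces. The dataclass below is a stand-in for `ToolCallItem` (not the real class), and the helper simply mirrors the logic shown in the diff.

```python
import uuid
from dataclasses import dataclass

@dataclass
class CallStub:        # stand-in for function_call_parser.ToolCallItem
    name: str
    tool_index: int    # position of the call *within this message*

def make_tool_call_id(parser: str, call: CallStub, history_tool_calls_cnt: int) -> str:
    """Mirrors _process_tool_call_id above (names here are illustrative)."""
    if parser != "kimi_k2":
        # Most parsers just need an opaque, unique id.
        return f"call_{uuid.uuid4().hex[:24]}"
    # Kimi-K2 expects functions.{name}:{index}, where the index counts tool
    # calls across the whole conversation, so the per-message tool_index is
    # offset by the number of tool calls already present in the history.
    return f"functions.{call.name}:{history_tool_calls_cnt + call.tool_index}"

print(make_tool_call_id("llama3", CallStub("get_weather", 0), 2))   # e.g. call_3f9c1a...
print(make_tool_call_id("kimi_k2", CallStub("get_weather", 0), 2))  # functions.get_weather:2
print(make_tool_call_id("kimi_k2", CallStub("get_time", 1), 2))     # functions.get_time:3
```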
4 changes: 3 additions & 1 deletion lightllm/server/core/objs/start_args_type.py
@@ -30,7 +30,9 @@ class StartArgs:
mem_fraction: float = field(default=0.9)
batch_max_tokens: Optional[int] = field(default=None)
eos_id: List[int] = field(default_factory=list)
tool_call_parser: Optional[str] = field(default=None, metadata={"choices": ["llama3", "qwen25", "mistral"]})
tool_call_parser: Optional[str] = field(
default=None, metadata={"choices": ["llama3", "qwen25", "mistral", "deepseekv3", "kimi_k2", "qwen"]}
)
running_max_req_size: int = field(default=1000)
tp: int = field(default=1)
dp: int = field(default=1)