diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/gpt_oss_fused_moe_weight_tp.py b/lightllm/common/basemodel/layer_weights/meta_weights/gpt_oss_fused_moe_weight_tp.py
index b1f6ce901..df72cc620 100644
--- a/lightllm/common/basemodel/layer_weights/meta_weights/gpt_oss_fused_moe_weight_tp.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/gpt_oss_fused_moe_weight_tp.py
@@ -138,10 +138,10 @@ def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_t
topk_ids=topk_ids,
inplace=True,
use_fp8_w8a8=use_fp8_w8a8,
- w1_scale=w1_scale,
- w2_scale=w2_scale,
w1_bias=self.w1_bias,
w2_bias=self.w2_bias / self.tp_world_size_,
+ w1_scale=w1_scale,
+ w2_scale=w2_scale,
layout="interleaved",
alpha=self.alpha,
limit=self.limit,
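
Note: the reorder above is purely cosmetic at the call site. These are keyword
arguments, which Python binds by name regardless of position; the swap just
mirrors the new parameter order of fused_experts below. A minimal sketch (toy
function, not the LightLLM API):

    def f(*, w1_bias=None, w2_bias=None, w1_scale=None, w2_scale=None):
        return (w1_bias, w2_bias, w1_scale, w2_scale)

    # Order at the call site is irrelevant for keyword arguments:
    assert f(w1_scale=1, w2_scale=2, w1_bias=3, w2_bias=4) == (3, 4, 1, 2)
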
diff --git a/lightllm/common/fused_moe/grouped_fused_moe.py b/lightllm/common/fused_moe/grouped_fused_moe.py
index 72a49970e..d4e86059b 100644
--- a/lightllm/common/fused_moe/grouped_fused_moe.py
+++ b/lightllm/common/fused_moe/grouped_fused_moe.py
@@ -764,13 +764,13 @@ def fused_experts_impl(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
- w1_bias: Optional[torch.Tensor],
- w2_bias: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
+ w1_bias: Optional[torch.Tensor] = None,
+ w2_bias: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
@@ -890,13 +890,13 @@ def inplace_fused_experts_impl(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
- # optional bias for w1 and w2
- w1_bias: Optional[torch.Tensor],
- w2_bias: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
+ # optional bias for w1 and w2
+ w1_bias: Optional[torch.Tensor] = None,
+ w2_bias: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
@@ -909,13 +909,13 @@ def inplace_fused_experts_impl(
hidden_states,
w1,
w2,
- w1_bias,
- w2_bias,
topk_weights,
topk_ids,
True,
use_fp8_w8a8,
use_int8_w8a16,
+ w1_bias,
+ w2_bias,
w1_scale,
w2_scale,
a1_scale,
@@ -930,13 +930,13 @@ def inplace_fused_experts_impl_fake(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
- # optional bias for w1 and w2
- w1_bias: Optional[torch.Tensor],
- w2_bias: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
+ # optional bias for w1 and w2
+ w1_bias: Optional[torch.Tensor] = None,
+ w2_bias: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
@@ -960,13 +960,13 @@ def outplace_fused_experts_impl(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
- # optional bias for w1 and w2
- w1_bias: Optional[torch.Tensor],
- w2_bias: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
+ # optional bias for w1 and w2
+ w1_bias: Optional[torch.Tensor] = None,
+ w2_bias: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
@@ -979,13 +979,13 @@ def outplace_fused_experts_impl(
hidden_states,
w1,
w2,
- w1_bias,
- w2_bias,
topk_weights,
topk_ids,
False,
use_fp8_w8a8,
use_int8_w8a16,
+ w1_bias,
+ w2_bias,
w1_scale,
w2_scale,
a1_scale,
@@ -1051,12 +1051,12 @@ def fused_experts(
hidden_states,
w1,
w2,
- w1_bias,
- w2_bias,
topk_weights,
topk_ids,
use_fp8_w8a8,
use_int8_w8a16,
+ w1_bias,
+ w2_bias,
w1_scale,
w2_scale,
a1_scale,
@@ -1071,12 +1071,12 @@ def fused_experts(
hidden_states,
w1,
w2,
- w1_bias,
- w2_bias,
topk_weights,
topk_ids,
use_fp8_w8a8,
use_int8_w8a16,
+ w1_bias,
+ w2_bias,
w1_scale,
w2_scale,
a1_scale,
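
Note: the signature changes in this file all follow one pattern — w1_bias and
w2_bias move behind the existing flags and gain None defaults. A plausible
motivation (sketched below with toy names, not the real functions) is that
appending optional parameters at the end keeps every pre-existing positional
call site valid, whereas parameters wedged into the middle silently shift them:

    # Before: bias squeezed between w2 and topk_weights forces every
    # positional caller to account for it.
    def fused_v1(hidden_states, w1, w2, w1_bias, w2_bias, topk_weights, topk_ids):
        ...

    # After: old positional calls such as fused_v2(h, w1, w2, tw, ti)
    # keep working, and bias-aware callers opt in by keyword.
    def fused_v2(hidden_states, w1, w2, topk_weights, topk_ids,
                 w1_bias=None, w2_bias=None):
        ...
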
diff --git a/lightllm/server/api_models.py b/lightllm/server/api_models.py
index 14615f72f..a89bcfb66 100644
--- a/lightllm/server/api_models.py
+++ b/lightllm/server/api_models.py
@@ -101,6 +101,7 @@ class ChatCompletionRequest(BaseModel):
tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = Field(
default="auto", examples=["none"]
) # noqa
+ parallel_tool_calls: Optional[bool] = True
# Additional parameters supported by LightLLM
do_sample: Optional[bool] = False
@@ -122,11 +123,36 @@ class FunctionResponse(BaseModel):
class ToolCall(BaseModel):
"""Tool call response."""
- id: str
+ id: Optional[str] = None
+ index: Optional[int] = None
type: Literal["function"] = "function"
function: FunctionResponse
+class ChatCompletionMessageGenericParam(BaseModel):
+ role: Literal["system", "assistant", "tool", "function"]
+ content: Union[str, List[MessageContent], None] = Field(default=None)
+ tool_call_id: Optional[str] = None
+ name: Optional[str] = None
+ reasoning_content: Optional[str] = None
+ tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
+
+ @field_validator("role", mode="before")
+ @classmethod
+ def _normalize_role(cls, v):
+ if isinstance(v, str):
+ v_lower = v.lower()
+ if v_lower not in {"system", "assistant", "tool", "function"}:
+ raise ValueError(
+ "'role' must be one of 'system', 'assistant', 'tool', or 'function' (case-insensitive)."
+ )
+ return v_lower
+ raise ValueError("'role' must be a string")
+
+
+ChatCompletionMessageParam = Union[ChatCompletionMessageGenericParam, Message]
+
+
class UsageInfo(BaseModel):
prompt_tokens: int = 0
completion_tokens: Optional[int] = 0
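
Note: the role validator above relies on pydantic v2 semantics
(field_validator with mode="before" runs on the raw input before type
coercion). A minimal standalone sketch of the behavior, assuming pydantic v2:

    from pydantic import BaseModel, ValidationError, field_validator

    class Msg(BaseModel):
        role: str

        @field_validator("role", mode="before")
        @classmethod
        def _normalize_role(cls, v):
            if isinstance(v, str):
                v_lower = v.lower()
                if v_lower not in {"system", "assistant", "tool", "function"}:
                    raise ValueError("unsupported role")
                return v_lower
            raise ValueError("'role' must be a string")

    assert Msg(role="Assistant").role == "assistant"   # case-normalized
    try:
        Msg(role="user")                               # not in the allowed set
    except ValidationError:
        pass
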
diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py
index dd484103e..fc6733c7e 100644
--- a/lightllm/server/api_openai.py
+++ b/lightllm/server/api_openai.py
@@ -9,7 +9,7 @@
import pickle
import uuid
-from .function_call_parser import TOOLS_TAG_LIST, FunctionCallParser
+from .function_call_parser import TOOLS_TAG_LIST, FunctionCallParser, ToolCallItem
from .build_prompt import build_prompt, init_tokenizer
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@@ -63,6 +63,52 @@ def create_error_response(status_code: HTTPStatus, message: str) -> JSONResponse
return JSONResponse({"message": message}, status_code=status_code.value)
+def _process_tool_call_id(
+ tool_call_parser,
+ call_item: ToolCallItem,
+ history_tool_calls_cnt: int,
+) -> str:
+ """Process for generating a new and unique `tool_call_id`"""
+ if tool_call_parser != "kimi_k2":
+ # A simple uuid is sufficient for all models except for Kimi-K2.
+ tool_call_id = f"call_{uuid.uuid4().hex[:24]}"
+ return tool_call_id
+ else:
+ # Align with Kimi-K2 format: functions.{name}:{index}
+ # Kimi-K2 allows multiple tool_calls in one message;
+        # The parser (following SGLang) sets call_item.tool_index to the *local* position inside that message,
+        # so the index must be offset by the number of tool calls already in the history,
+        # i.e. `history_tool_calls_cnt + call_item.tool_index`, to keep ids globally unique and properly ordered.
+        tool_call_id = f"functions.{call_item.name}:{history_tool_calls_cnt + call_item.tool_index}"
+        logger.debug(
+            f"Process tool call idx, parser: {tool_call_parser}, "
+            f"tool_call_id: {tool_call_id}, "
+            f"history_cnt: {history_tool_calls_cnt}"
+        )
+ return tool_call_id
+
+
+def _get_history_tool_calls_cnt(request: ChatCompletionRequest) -> int:
+ """Counts the number of tool calls in the request's message history.
+
+    NOTE: This is only needed for models whose tool_call ids embed a
+    self-incrementing index over the conversation history, such as Kimi-K2.
+
+ Args:
+ request: The chat completion request object.
+
+ Returns:
+ The total number of tool calls in the history, or 0 if not applicable.
+ """
+ messages = getattr(request, "messages", [])
+ idx = 0
+ for msg in messages:
+ if msg.role == "assistant":
+ tool_calls = getattr(msg, "tool_calls", None)
+ idx += len(list(tool_calls)) if tool_calls is not None else 0 # noqa
+ return idx
+
+
async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Request) -> Response:
from .api_http import g_objs
@@ -180,26 +226,31 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req
if tool_choice != "none" and any([i in text for i in TOOLS_TAG_LIST]):
if finish_reason == "stop":
- finish_reason = "function_call"
+ finish_reason = "tool_calls"
try:
            # Provide a default value for tool_call_parser
tool_parser = getattr(g_objs.args, "tool_call_parser", None) or "llama3"
parser = FunctionCallParser(tools, tool_parser)
full_normal_text, call_info_list = parser.parse_non_stream(text)
- tool_calls = [
- ToolCall(
- id=str(call_info.tool_index),
- function=FunctionResponse(name=call_info.name, arguments=call_info.parameters),
+ tool_calls = []
+ history_tool_calls_cnt = _get_history_tool_calls_cnt(request)
+ for call_info in call_info_list:
+ tool_id = _process_tool_call_id(tool_parser, call_info, history_tool_calls_cnt)
+ tool_calls.append(
+ ToolCall(
+ id=tool_id,
+ index=getattr(call_info, "tool_index", None),
+ function=FunctionResponse(name=call_info.name, arguments=call_info.parameters),
+ )
)
- for call_info in call_info_list
- ]
except Exception as e:
logger.error(f"Exception: {e}")
return create_error_response(
HTTPStatus.BAD_REQUEST,
"Failed to parse fc related info to json format!",
)
-
+ if finish_reason == "tool_calls":
+ text = ""
chat_message = ChatMessage(role="assistant", content=text, tool_calls=tool_calls)
choice = ChatCompletionResponseChoice(
index=i,
@@ -261,6 +312,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
yield f"data: {chunk.model_dump_json()}\n\n"
# 2) if we found calls, we output them as separate chunk(s)
+ history_tool_calls_cnt = _get_history_tool_calls_cnt(request)
for call_item in calls:
# transform call_item -> FunctionResponse + ToolCall
if finish_reason == "stop":
@@ -278,17 +330,27 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
remaining_call = expected_call.replace(actual_call, "", 1)
call_item.parameters = remaining_call
+ if call_item.name:
+ # First chunk: include ID and function name
+ tool_call_id = _process_tool_call_id(tool_parser, call_item, history_tool_calls_cnt)
+ function_name = call_item.name
+ else:
+ # Subsequent chunks: null ID and name for argument deltas
+ tool_call_id = None
+ function_name = None
+
tool_call = ToolCall(
- id=str(call_item.tool_index),
+ id=tool_call_id,
+ index=getattr(call_item, "tool_index", None),
function=FunctionResponse(
- name=call_item.name,
+ name=function_name,
arguments=call_item.parameters,
),
)
choice_data = ChatCompletionStreamResponseChoice(
index=0,
delta=DeltaMessage(role="assistant", tool_calls=[tool_call]),
- finish_reason="function_call",
+ finish_reason="tool_calls",
)
chunk = ChatCompletionStreamResponse(
id=group_request_id,
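
Note: to make the Kimi-K2 id scheme concrete — _get_history_tool_calls_cnt
counts the assistant tool calls already present in the conversation, and
_process_tool_call_id offsets the parser's per-message tool_index by that
count. A hedged walkthrough (hypothetical helper mirroring the kimi_k2 branch):

    def make_kimi_id(name: str, history_cnt: int, local_index: int) -> str:
        return f"functions.{name}:{history_cnt + local_index}"

    # Two tool calls already in history; a new message emits two more:
    assert make_kimi_id("get_weather", 2, 0) == "functions.get_weather:2"
    assert make_kimi_id("get_weather", 2, 1) == "functions.get_weather:3"
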
diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py
index ce5ef56fc..cd196978d 100644
--- a/lightllm/server/core/objs/start_args_type.py
+++ b/lightllm/server/core/objs/start_args_type.py
@@ -30,7 +30,9 @@ class StartArgs:
mem_fraction: float = field(default=0.9)
batch_max_tokens: Optional[int] = field(default=None)
eos_id: List[int] = field(default_factory=list)
- tool_call_parser: Optional[str] = field(default=None, metadata={"choices": ["llama3", "qwen25", "mistral"]})
+ tool_call_parser: Optional[str] = field(
+ default=None, metadata={"choices": ["llama3", "qwen25", "mistral", "deepseekv3", "kimi_k2", "qwen"]}
+ )
running_max_req_size: int = field(default=1000)
tp: int = field(default=1)
dp: int = field(default=1)
diff --git a/lightllm/server/function_call_parser.py b/lightllm/server/function_call_parser.py
index 22256a745..9f614195e 100644
--- a/lightllm/server/function_call_parser.py
+++ b/lightllm/server/function_call_parser.py
@@ -12,16 +12,21 @@
# limitations under the License.
import json
+import orjson
import logging
import re
from abc import ABC, abstractmethod
from json import JSONDecodeError, JSONDecoder
-from typing import Any, Dict, List, Optional, Tuple
+from json.decoder import WHITESPACE
+from typing import Any, Dict, List, Optional, Tuple, Type
import partial_json_parser
+from partial_json_parser.core.exceptions import MalformedJSON
from partial_json_parser.core.options import Allow
from pydantic import BaseModel, Field
+from .api_models import Tool
+
logger = logging.getLogger(__name__)
TOOLS_TAG_LIST = [
@@ -33,14 +38,6 @@
]
-class Function(BaseModel):
- """Function Tool Template."""
-
- description: Optional[str] = Field(default=None, examples=[None])
- name: Optional[str] = None
- parameters: Optional[object] = None
-
-
class ToolCallItem(BaseModel):
"""Simple encapsulation of the parsed ToolCall result for easier usage in streaming contexts."""
@@ -49,6 +46,14 @@ class ToolCallItem(BaseModel):
parameters: str # JSON string
+class StreamingParseResult:
+ """Result of streaming incremental parsing."""
+
+ def __init__(self, normal_text: str = "", calls: Optional[List[ToolCallItem]] = None):
+ self.normal_text = normal_text
+ self.calls = calls or []
+
+
def _find_common_prefix(s1: str, s2: str) -> str:
prefix = ""
min_length = min(len(s1), len(s2))
@@ -61,63 +66,89 @@ def _find_common_prefix(s1: str, s2: str) -> str:
def _partial_json_loads(input_str: str, flags: Allow) -> Tuple[Any, int]:
+ """
+ Parse incomplete or partial JSON strings commonly encountered during streaming.
+
+ Args:
+ input_str (str): The potentially incomplete JSON string to parse.
+ flags (Allow): Bitwise flags controlling what types of partial data are allowed.
+ Common flags include:
+ - Allow.STR: Allow partial strings (e.g., '"hello wo' -> 'hello wo')
+ - Allow.OBJ: Allow partial objects (e.g., '{"key":' -> {'key': None})
+ - Allow.ARR: Allow partial arrays (e.g., '[1, 2,' -> [1, 2])
+ - Allow.ALL: Allow all types of partial data
+
+ Returns:
+ Tuple[Any, int]: A tuple containing:
+ - parsed_object: The Python object parsed from the JSON
+ - consumed_length: Number of characters consumed from input_str
+ """
try:
return (partial_json_parser.loads(input_str, flags), len(input_str))
- except JSONDecodeError as e:
- if "Extra data" in e.msg:
- dec = JSONDecoder()
- return dec.raw_decode(input_str)
+ except (JSONDecodeError, IndexError) as e:
+ msg = getattr(e, "msg", str(e))
+ if "Extra data" in msg or "pop from empty list" in msg:
+ start = WHITESPACE.match(input_str, 0).end()
+ obj, end = JSONDecoder().raw_decode(input_str, start)
+ return obj, end
raise
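
Note: a quick illustration of the two paths through _partial_json_loads,
assuming partial-json-parser's documented behavior (the outputs shown as
comments are expected results, not verified against a specific version):

    import partial_json_parser
    from partial_json_parser.core.options import Allow

    # Path 1: truncated JSON is completed as far as the flags allow.
    partial_json_parser.loads('{"city": "Par', Allow.ALL)
    # -> {'city': 'Par'}  (partial string kept because Allow.STR is set)

    # Path 2: complete JSON followed by extra text raises "Extra data";
    # the helper then falls back to JSONDecoder().raw_decode, which also
    # reports how many characters were consumed.
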
def _is_complete_json(input_str: str) -> bool:
try:
- json.loads(input_str)
+ orjson.loads(input_str)
return True
except JSONDecodeError:
return False
-class StreamingParseResult:
- """Result of streaming incremental parsing."""
-
- def __init__(self, normal_text: str = "", calls: Optional[List[ToolCallItem]] = None):
- self.normal_text = normal_text
- self.calls = calls or []
-
-
-class BaseFormatDetector:
+class BaseFormatDetector(ABC):
"""Base class providing two sets of interfaces: one-time and streaming incremental."""
def __init__(self):
- # initialize properties used for state when parsing tool calls in
+ # Streaming state management
+ # Buffer for accumulating incomplete patterns that arrive across multiple streaming chunks
self._buffer = ""
- # streaming mode
+ # Stores complete tool call info (name and arguments) for each tool being parsed.
+ # Used by serving layer for completion handling when streaming ends.
+ # Format: [{"name": str, "arguments": dict}, ...]
self.prev_tool_call_arr: List[Dict] = []
+ # Index of currently streaming tool call. Starts at -1 (no active tool),
+ # increments as each tool completes. Tracks which tool's arguments are streaming.
self.current_tool_id: int = -1
+ # Flag for whether current tool's name has been sent to client.
+ # Tool names sent first with empty parameters, then arguments stream incrementally.
self.current_tool_name_sent: bool = False
- self.streamed_args_for_tool: List[str] = [] # map what has been streamed for each tool so far to a list
+ # Tracks raw JSON string content streamed to client for each tool's arguments.
+ # Critical for serving layer to calculate remaining content when streaming ends.
+ # Each index corresponds to a tool_id. Example: ['{"location": "San Francisco"', '{"temp": 72']
+ self.streamed_args_for_tool: List[str] = []
+
+ # Token configuration (override in subclasses)
self.bot_token = ""
self.eot_token = ""
+ self.tool_call_separator = ", "
+
+ def _get_tool_indices(self, tools: List[Tool]) -> Dict[str, int]:
+ """
+ Get a mapping of tool names to their indices in the tools list.
+
+ This utility method creates a dictionary mapping function names to their
+ indices in the tools list, which is commonly needed for tool validation
+ and ToolCallItem creation.
- def parse_base_json(self, action: Any, tools: List[Function]) -> List[ToolCallItem]:
- tool_indices = {tool.function.name: i for i, tool in enumerate(tools) if tool.function.name}
+ Args:
+ tools: List of available tools
+
+ Returns:
+ Dictionary mapping tool names to their indices
+ """
+ return {tool.function.name: i for i, tool in enumerate(tools) if tool.function.name}
+
+ def parse_base_json(self, action: Any, tools: List[Tool]) -> List[ToolCallItem]:
+ tool_indices = self._get_tool_indices(tools)
if not isinstance(action, list):
- name = action.get("name")
- if not name or name not in tool_indices:
- logger.warning(f"Model attempted to call undefined function: {name}")
- return []
-
- return [
- ToolCallItem(
- tool_index=tool_indices[name],
- name=name,
- parameters=json.dumps(
- action.get("parameters") or action.get("arguments", {}),
- ensure_ascii=False,
- ),
- )
- ]
+ action = [action]
results = []
for act in action:
@@ -125,7 +156,7 @@ def parse_base_json(self, action: Any, tools: List[Function]) -> List[ToolCallIt
if name and name in tool_indices:
results.append(
ToolCallItem(
- tool_index=tool_indices[name],
+ tool_index=-1, # Caller should update this based on the actual tools array called
name=name,
parameters=json.dumps(
act.get("parameters") or act.get("arguments", {}),
@@ -133,108 +164,137 @@ def parse_base_json(self, action: Any, tools: List[Function]) -> List[ToolCallIt
),
)
)
+ else:
+ logger.warning(f"Model attempted to call undefined function: {name}")
return results
- def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
+ @abstractmethod
+ def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
"""
        Parses the text in one go and returns a StreamingParseResult.
        Here normal_text is the content that this parser will not consume as tool calls.
"""
- action = json.loads(text)
- return self.parse_base_json(action, tools)
+ action = orjson.loads(text)
+ return StreamingParseResult(calls=self.parse_base_json(action, tools))
+
+ def _ends_with_partial_token(self, buffer: str, bot_token: str) -> int:
+ """
+ Check if buffer ends with a partial bot_token.
+ Return the length of the partial bot_token.
+
+        For some formats, the bot_token is not a single token in the model's
+        vocabulary, e.g. `[TOOL_CALLS] [` in Mistral.
+ """
+ for i in range(1, min(len(buffer) + 1, len(bot_token))):
+ if bot_token.startswith(buffer[-i:]):
+ return i
+ return 0
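
Note: concretely, the suffix scan keeps a trailing fragment buffered instead
of flushing it as normal text. For the Mistral bot_token:

    # The buffer ends with "[TOOL", a 5-char prefix of "[TOOL_CALLS] [",
    # so 5 characters stay in the buffer awaiting the next chunk:
    detector._ends_with_partial_token("The weather is [TOOL", "[TOOL_CALLS] [")  # -> 5
    detector._ends_with_partial_token("plain text", "[TOOL_CALLS] [")            # -> 0
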
- def parse_streaming_increment(self, new_text: str, tools: List[Function]) -> StreamingParseResult:
+ def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> StreamingParseResult:
"""
Streaming incremental parsing with tool validation.
+
+ This base implementation works best with formats where:
+ 1. bot_token is followed immediately by JSON (e.g., bot_token + JSON_array)
+ 2. JSON can be parsed incrementally using partial_json_loads
+ 3. Multiple tool calls are separated by "; " or ", "
+
+ Examples of incompatible formats (need custom implementation, may reuse some logic from this class):
+ - Each tool call is wrapped in a separate block: See Qwen25Detector
+ - Multiple separate blocks: [TOOL_CALLS] [...] \n [TOOL_CALLS] [...]
+        - Pythonic-style tool calls
+
+ For incompatible formats, detectors should override this method with custom logic.
"""
# Append new text to buffer
self._buffer += new_text
current_text = self._buffer
- if not (self.bot_token in current_text or current_text.startswith("{")):
- self._buffer = ""
- if self.eot_token in new_text:
- new_text = new_text.replace(self.eot_token, "")
- return StreamingParseResult(normal_text=new_text)
+
+ # The current_text has tool_call if it is the start of a new tool call sequence
+ # or it is the start of a new tool call after a tool call separator, when there is a previous tool call
+ if not (
+ self.has_tool_call(current_text)
+ or (self.current_tool_id > 0 and current_text.startswith(self.tool_call_separator))
+ ):
+ # Only clear buffer if we're sure no tool call is starting
+ if not self._ends_with_partial_token(self._buffer, self.bot_token):
+ normal_text = self._buffer
+ self._buffer = ""
+ if self.eot_token in normal_text:
+ normal_text = normal_text.replace(self.eot_token, "")
+ return StreamingParseResult(normal_text=normal_text)
+ else:
+ # Might be partial bot_token, keep buffering
+ return StreamingParseResult()
# Build tool indices if not already built
if not hasattr(self, "_tool_indices"):
- self._tool_indices = {
- tool.function.name: i for i, tool in enumerate(tools) if tool.function and tool.function.name
- }
+ self._tool_indices = self._get_tool_indices(tools)
flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
+
try:
- tool_call_arr = []
- is_complete = []
try:
- start_idx = len(self.bot_token) if current_text.startswith(self.bot_token) else 0
- while start_idx < len(current_text):
- (obj, end_idx) = _partial_json_loads(current_text[start_idx:], flags)
- is_complete.append(_is_complete_json(current_text[start_idx : start_idx + end_idx]))
- start_idx += end_idx + len("; ")
-
- # Validate tool name if present
- if "name" in obj and obj["name"] not in self._tool_indices:
- # Invalid tool name - reset state
- self._buffer = ""
- self.current_tool_id = -1
- self.current_tool_name_sent = False
- if self.streamed_args_for_tool:
- self.streamed_args_for_tool.pop()
- return StreamingParseResult()
+ tool_call_pos = current_text.find(self.bot_token)
+ if tool_call_pos != -1:
+ start_idx = tool_call_pos + len(self.bot_token)
+ elif self.current_tool_id > 0 and current_text.startswith(self.tool_call_separator):
+ start_idx = len(self.tool_call_separator)
+ else:
+ start_idx = 0
- # Handle parameters/arguments consistency
- if "parameters" in obj:
- assert "arguments" not in obj, "model generated both parameters and arguments"
- obj["arguments"] = obj["parameters"]
- tool_call_arr.append(obj)
+ if start_idx >= len(current_text):
+ return StreamingParseResult()
- except partial_json_parser.core.exceptions.MalformedJSON:
- return StreamingParseResult()
+ (obj, end_idx) = _partial_json_loads(current_text[start_idx:], flags)
- if len(tool_call_arr) == 0:
- return StreamingParseResult()
+ is_current_complete = _is_complete_json(current_text[start_idx : start_idx + end_idx])
- current_tool_call: Dict = tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {}
+ # Validate tool name if present
+ if "name" in obj and obj["name"] not in self._tool_indices:
+ # Invalid tool name - reset state
+ self._buffer = ""
+ self.current_tool_id = -1
+ self.current_tool_name_sent = False
+ if self.streamed_args_for_tool:
+ self.streamed_args_for_tool.pop()
+ return StreamingParseResult()
- # Handle new tool in array
- if len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1:
- if self.current_tool_id >= 0:
- cur_arguments = current_tool_call.get("arguments")
- if cur_arguments:
- cur_args_json = json.dumps(cur_arguments)
- sent = len(self.streamed_args_for_tool[self.current_tool_id])
- argument_diff = cur_args_json[sent:]
+ # Handle parameters/arguments consistency
+            # NOTE: we assume here that the obj is always a partial of a single tool call
+ if "parameters" in obj:
+ assert "arguments" not in obj, "model generated both parameters and arguments"
+ obj["arguments"] = obj["parameters"]
- res = StreamingParseResult(
- calls=[
- ToolCallItem(
- tool_index=self.current_tool_id,
- name="",
- parameters=argument_diff,
- )
- ],
- )
- self.streamed_args_for_tool[self.current_tool_id] += argument_diff
- else:
- res = StreamingParseResult()
- else:
- res = StreamingParseResult()
+ current_tool_call = obj
- self.current_tool_id = len(tool_call_arr) - 1
- self.current_tool_name_sent = False
- self.streamed_args_for_tool.append("")
- return res
+ except MalformedJSON:
+ return StreamingParseResult()
+
+ if not current_tool_call:
+ return StreamingParseResult()
- # Handle tool name
- elif not self.current_tool_name_sent:
+ # Case 1: Handle tool name streaming
+ # This happens when we encounter a tool but haven't sent its name yet
+ if not self.current_tool_name_sent:
function_name = current_tool_call.get("name")
+
if function_name and function_name in self._tool_indices:
+ # If this is a new tool (current_tool_id was -1), initialize it
+ if self.current_tool_id == -1:
+ self.current_tool_id = 0
+ self.streamed_args_for_tool.append("")
+ # If this is a subsequent tool, ensure streamed_args_for_tool is large enough
+ elif self.current_tool_id >= len(self.streamed_args_for_tool):
+ while len(self.streamed_args_for_tool) <= self.current_tool_id:
+ self.streamed_args_for_tool.append("")
+
+ # Send the tool name with empty parameters
res = StreamingParseResult(
calls=[
ToolCallItem(
- tool_index=self._tool_indices[function_name],
+ tool_index=self.current_tool_id,
name=function_name,
parameters="",
)
@@ -244,44 +304,66 @@ def parse_streaming_increment(self, new_text: str, tools: List[Function]) -> Str
else:
res = StreamingParseResult()
- # Handle streaming arguments
+ # Case 2: Handle streaming arguments
+ # This happens when we've already sent the tool name and now need to stream arguments incrementally
else:
cur_arguments = current_tool_call.get("arguments")
res = StreamingParseResult()
if cur_arguments:
+ # Calculate how much of the arguments we've already streamed
sent = len(self.streamed_args_for_tool[self.current_tool_id])
cur_args_json = json.dumps(cur_arguments)
- prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get("arguments")
+ prev_arguments = None
+ if self.current_tool_id < len(self.prev_tool_call_arr):
+ prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get("arguments")
argument_diff = None
- if is_complete[self.current_tool_id]:
+
+ # If the current tool's JSON is complete, send all remaining arguments
+ if is_current_complete:
argument_diff = cur_args_json[sent:]
- self._buffer = ""
- self.prev_tool_call_arr[self.current_tool_id].clear()
+ completing_tool_id = self.current_tool_id # Save the ID of the tool that's completing
+
+ # Only remove the processed portion, keep unprocessed content
+ self._buffer = current_text[start_idx + end_idx :]
+
+ if self.current_tool_id < len(self.prev_tool_call_arr):
+ self.prev_tool_call_arr[self.current_tool_id].clear()
self.current_tool_name_sent = False
self.streamed_args_for_tool[self.current_tool_id] = ""
+ self.current_tool_id += 1
+ # If the tool is still being parsed, send incremental changes
elif prev_arguments:
prev_args_json = json.dumps(prev_arguments)
if cur_args_json != prev_args_json:
prefix = _find_common_prefix(prev_args_json, cur_args_json)
argument_diff = prefix[sent:]
+ # Send the argument diff if there's something new
if argument_diff is not None:
+ # Use the correct tool_index: completing_tool_id for completed tools,
+ # current_tool_id for ongoing
+ tool_index_to_use = completing_tool_id if is_current_complete else self.current_tool_id
res = StreamingParseResult(
calls=[
ToolCallItem(
- tool_index=self.current_tool_id,
- name="",
+ tool_index=tool_index_to_use,
parameters=argument_diff,
)
],
)
- if not is_complete[self.current_tool_id]:
+ if not is_current_complete:
self.streamed_args_for_tool[self.current_tool_id] += argument_diff
- self.prev_tool_call_arr = tool_call_arr
+ # Update prev_tool_call_arr with current state
+ if self.current_tool_id >= 0:
+ # Ensure prev_tool_call_arr is large enough
+ while len(self.prev_tool_call_arr) <= self.current_tool_id:
+ self.prev_tool_call_arr.append({})
+ self.prev_tool_call_arr[self.current_tool_id] = current_tool_call
+
return res
except Exception as e:
@@ -291,9 +373,19 @@ def parse_streaming_increment(self, new_text: str, tools: List[Function]) -> Str
class Qwen25Detector(BaseFormatDetector):
"""
- Detector for Qwen 2.5 models.
- Assumes function call format:
- {"name":"xxx", "arguments":{...}}
+ Detector for Qwen 2.5 and Qwen 3 model function call format.
+
+ Format Structure:
+ ```
+ \n{"name":"func1", "arguments":{...}}\n
+ \n\n{"name":"func2", "arguments":{...}}\n
+ ```
+
+ Key Components:
+    - Tool Call Tags: `<tool_call>` and `</tool_call>` wrap each individual call
+ - Function Call Object: JSON object with "name" and "arguments" fields
+
+ Reference: https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct?chat_template=default
"""
def __init__(self):
@@ -301,10 +393,16 @@ def __init__(self):
Initializes the detector with necessary state variables.
"""
super().__init__()
- self.bot_token = ""
- self.eot_token = ""
+ self.bot_token = "\n"
+ self.eot_token = "\n"
+ self.tool_call_separator = "\n"
+ self._normal_text_buffer = "" # Buffer for handling partial end tokens
- def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
+ def has_tool_call(self, text: str) -> bool:
+ """Check if the text contains a Qwen 2.5 format tool call."""
+ return self.bot_token in text
+
+ def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
"""
One-time parsing: Detects and parses tool calls in the provided text.
@@ -312,22 +410,70 @@ def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallIte
:param tools: List of available tools.
:return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
"""
- if "" not in text:
- return []
- pattern = r"(.*?)"
+ idx = text.find(self.bot_token)
+ normal_text = text[:idx].strip() if idx != -1 else text
+ if self.bot_token not in text:
+ return StreamingParseResult(normal_text=normal_text, calls=[])
+
+        # Find all <tool_call>\n...\n</tool_call> blocks
+ pattern = rf"{re.escape(self.bot_token)}(.*?){re.escape(self.eot_token)}"
match_result_list = re.findall(pattern, text, re.DOTALL)
calls = []
for match_result in match_result_list:
- match_result = json.loads(match_result)
- calls.extend(self.parse_base_json(match_result, tools))
- return calls
+ try:
+ parsed_call = json.loads(match_result.strip())
+ calls.extend(self.parse_base_json(parsed_call, tools))
+ except json.JSONDecodeError as e:
+ logger.warning(f"Failed to parse JSON part: {match_result}, JSON parse error: {str(e)}")
+ continue
+ return StreamingParseResult(normal_text=normal_text, calls=calls)
+
+ def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> StreamingParseResult:
+ """
+ Streaming incremental parsing for Qwen 2.5 tool calls.
+ Uses base class implementation with buffering to handle partial end tokens.
+ """
+ result = super().parse_streaming_increment(new_text, tools)
+
+ # Handle partial end tokens that are streamed character by character
+ if result.normal_text:
+ self._normal_text_buffer += result.normal_text
+
+ # Check if buffer contains complete end token (without leading newline)
+            end_token_without_newline = self.eot_token[1:]  # "</tool_call>"
+ if end_token_without_newline in self._normal_text_buffer:
+ cleaned_text = self._normal_text_buffer.replace(end_token_without_newline, "")
+ self._normal_text_buffer = ""
+ result.normal_text = cleaned_text
+ else:
+ # Check if buffer might contain partial end token at the end
+ partial_match_len = self._ends_with_partial_token(self._normal_text_buffer, end_token_without_newline)
+
+ if partial_match_len:
+ # Keep potential partial match in buffer, return the rest
+ result.normal_text = self._normal_text_buffer[:-partial_match_len]
+ self._normal_text_buffer = self._normal_text_buffer[-partial_match_len:]
+ else:
+ # No partial match, return all buffered text
+ result.normal_text = self._normal_text_buffer
+ self._normal_text_buffer = ""
+
+ return result
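
Note: the buffering above handles an end tag split across chunks, e.g. when
"</tool_call>" arrives piece by piece (hypothetical chunking):

    # chunk "</tool"   -> partial match of the end token; held in
    #                     _normal_text_buffer, nothing emitted
    # chunk "_call>"   -> buffer now contains the full "</tool_call>",
    #                     which is stripped before the text is emitted
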
class MistralDetector(BaseFormatDetector):
"""
- Detector for Mistral models.
- Assumes function call format:
- <|action_start|><|plugin|>{"name":"xxx", "arguments":{...}}<|action_end|>
+ Detector for Mistral model function call format.
+
+ The Mistral format uses a simple bracket-delimited structure with JSON arrays
+ containing function call objects.
+
+ Format Structure:
+ ```
+ [TOOL_CALLS] [{"name": "function_name", "arguments": {json_args}}, ...]
+ ```
+
+ Reference: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3?chat_template=default
"""
def __init__(self):
@@ -336,27 +482,15 @@ def __init__(self):
"""
super().__init__()
self.bot_token = "[TOOL_CALLS] ["
+ self.eot_token = "]"
self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
+ self.tool_call_separator = ", "
- def _clean_text(self, text: str) -> str:
- """
- clean text to only leave ''[TOOL_CALLS] [{"name": xxx, "arguments": {xxx}}]'
- for example,
- text = '[TOOL_CALLS] [{"name": "get_current_weather",
- "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]\n\n
- Today\'s weather in Boston is :{function call result} (in Fahrenheit)\n\n
- If you prefer Celsius, please let me know.'
- return '[TOOL_CALLS] [{"name": "get_current_weather",
- "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]'
- The key pattern is [TOOL_CALLS] [...]
- """
- find_results = re.findall(r"\[TOOL_CALLS\] \[.*?\]", text, re.DOTALL)
- if len(find_results) > 0:
- return find_results[0]
- else:
- return ""
+ def has_tool_call(self, text: str) -> bool:
+ """Check if the text contains a Mistral format tool call."""
+ return self.bot_token in text
- def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
+ def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
"""
One-time parsing: Detects and parses tool calls in the provided text.
@@ -364,144 +498,775 @@ def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallIte
:param tools: List of available tools.
:return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
"""
- text = self._clean_text(text)
- tool_content = text.replace("[TOOL_CALLS]", "").strip()
- raw_tool_calls = self.tool_call_regex.findall(tool_content)
+ idx = text.find(self.bot_token)
+ normal_text = text[:idx].strip() if idx != -1 else text
+
+ if self.bot_token not in text:
+ return StreamingParseResult(normal_text=normal_text, calls=[])
+
+ # Extract the JSON array part from [TOOL_CALLS] [...]
+ # Use bracket counting to properly handle nested brackets in JSON content
+ json_array_str = self._extract_json_array(text)
+ if not json_array_str:
+ return StreamingParseResult(normal_text=normal_text, calls=[])
+
calls = []
- if len(raw_tool_calls) > 0:
- raw_tool_call = raw_tool_calls[0]
- function_call_arr = json.loads(raw_tool_call)
- for match_result in function_call_arr:
- calls.extend(self.parse_base_json(match_result, tools))
- return calls
+ try:
+ function_call_arr = json.loads(json_array_str)
+ # Handle both single object and array of objects
+ if not isinstance(function_call_arr, list):
+ function_call_arr = [function_call_arr]
+ calls = self.parse_base_json(function_call_arr, tools)
+ except json.JSONDecodeError as e:
+ logger.warning(f"Failed to parse JSON part: {json_array_str}, JSON parse error: {str(e)}")
+
+ return StreamingParseResult(normal_text=normal_text, calls=calls)
+
+ def _extract_json_array(self, text: str) -> str:
+ """
+ Extract the JSON array part using bracket counting to handle nested brackets.
+
+ :param text: The complete text containing [TOOL_CALLS] [...]
+ :return: The JSON array string or None if not found
+ """
+ start_idx = text.find(self.bot_token)
+ if start_idx == -1:
+ return None
+
+ # Start from the opening bracket after [TOOL_CALLS]
+ json_start = start_idx + len(self.bot_token) - 1 # -1 to include the opening bracket
+ bracket_count = 0
+ in_string = False
+ escape_next = False
+
+ for i in range(json_start, len(text)):
+ char = text[i]
+
+ if escape_next:
+ escape_next = False
+ continue
+
+ if char == "\\":
+ escape_next = True
+ continue
+
+ if char == '"' and not escape_next:
+ in_string = not in_string
+ continue
+
+ if not in_string:
+ if char == "[":
+ bracket_count += 1
+ elif char == "]":
+ bracket_count -= 1
+ if bracket_count == 0:
+ return text[json_start : i + 1]
+
+ return None
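
Note: the bracket counter is what lets nested arrays inside arguments survive;
a non-greedy regex like r"\[.*?\]" would stop at the first ']'. Expected
behavior on a hand-made example:

    text = '[TOOL_CALLS] [{"name": "pick", "arguments": {"ids": [1, 2]}}] tail'
    # _extract_json_array ignores brackets inside quoted strings and balances
    # the inner "[1, 2]", returning:
    # '[{"name": "pick", "arguments": {"ids": [1, 2]}}]'
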
class Llama32Detector(BaseFormatDetector):
"""
- Detector for Llama 3.2 models.
- Assumes function call format:
- <|python_tag|>{"name":"xxx", "arguments":{...}}
+ Detector for Llama 3.2 models with json tool call format.
+
+ Format Structure:
+ ```
+ {"name":"xxx", "arguments":{...}}
+ ```
"""
def __init__(self):
super().__init__()
self.bot_token = "<|python_tag|>"
-
- def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
+        # NOTE: technically, Llama 3.2 doesn't support parallel tool calls well;
+        # it needs specific prompt engineering to produce them reliably.
+        # Here we use ';' as the separator, which may cause compatibility issues
+        # if users define a different separator in their prompt.
+ self.tool_call_separator = ";"
+
+ def has_tool_call(self, text: str) -> bool:
+ """Check if the text contains a Llama 3.2 format tool call."""
+ # depending on the prompt format the Llama model may or may not
+ # prefix the output with the <|python_tag|> token
+ return "<|python_tag|>" in text or text.startswith("{")
+
+ def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
"""Parse function calls from text, handling multiple JSON objects."""
- if "<|python_tag|>" not in text:
- return []
-
- _, action_text = text.split("<|python_tag|>")
+ if "<|python_tag|>" not in text and not text.startswith("{"):
+ return StreamingParseResult(normal_text=text, calls=[])
- # Split by semicolon and process each part
- json_parts = [part.strip() for part in action_text.split(";") if part.strip()]
+ if "<|python_tag|>" in text:
+ normal_text, action_text = text.split("<|python_tag|>", maxsplit=1)
+ else:
+ normal_text, action_text = "", text
+ decoder = json.JSONDecoder()
+ idx = 0
+        safe_idx = idx  # index just past the last successfully parsed JSON object (and its separator)
all_actions = []
- for part in json_parts:
+ action_text_len = len(action_text)
+ while idx < action_text_len:
try:
- # Parse each individual JSON object
- action = json.loads(part)
- all_actions.append(action)
+ obj, end = decoder.raw_decode(action_text[idx:])
+ all_actions.append(obj)
+ idx += end + len(self.tool_call_separator)
+ safe_idx = idx
except json.JSONDecodeError as e:
- logger.warning(f"Failed to parse JSON part: {part}")
- logger.warning(f"JSON parse error: {str(e)}")
+ # Find where next `{"name"` appears and try again
+ logger.warning(f"Failed to parse JSON part: {action_text[idx:]}, JSON parse error: {str(e)}")
+ next_obj_start = action_text.find('{"name":', idx + 1)
+ if next_obj_start == -1:
+ break
+ idx = next_obj_start
continue
# Only process if we found valid JSON objects
- if all_actions:
- return self.parse_base_json(all_actions, tools)
+ calls = self.parse_base_json(all_actions, tools) if all_actions else []
+        # Use safe_idx so trailing text after the last valid JSON object (including any invalid remainder) is preserved
+ trailing_text = action_text[safe_idx:].strip() if safe_idx < action_text_len else ""
+ return StreamingParseResult(normal_text=normal_text + trailing_text, calls=calls)
+
+
+class KimiK2Detector(BaseFormatDetector):
+ """
+ Detector for Kimi K2 model function call format.
+
+ Format Structure:
+ ```
+ <|tool_calls_section_begin|>
+ <|tool_call_begin|>functions.{func_name}:{index}<|tool_call_argument_begin|>{json_args}<|tool_call_end|>
+ <|tool_calls_section_end|>
+ ```
+
+ Reference: https://huggingface.co/moonshotai/Kimi-K2-Instruct/blob/main/docs/tool_call_guidance.md
+ """
+
+ def __init__(self):
+ super().__init__()
+
+ self.bot_token: str = "<|tool_calls_section_begin|>"
+ self.eot_token: str = "<|tool_calls_section_end|>"
+
+ self.tool_call_start_token: str = "<|tool_call_begin|>"
+ self.tool_call_end_token: str = "<|tool_call_end|>"
+
+ self.tool_call_regex = re.compile(
+ r"<\|tool_call_begin\|>\s*(?P[\w\.]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P\{.*?\})\s*<\|tool_call_end\|>" # noqa
+ )
+
+ self.stream_tool_call_portion_regex = re.compile(
+ r"<\|tool_call_begin\|>\s*(?P[\w\.]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P\{.*)" # noqa
+ )
+
+ self._last_arguments = ""
+
+ # Robust parser for ids like "functions.search:0" or fallback "search:0"
+        self.tool_call_id_regex = re.compile(r"^(?:functions\.)?(?P<name>[\w\.]+):(?P<index>\d+)$")
+
+ def has_tool_call(self, text: str) -> bool:
+ """Check if the text contains a KimiK2 format tool call."""
+ return self.bot_token in text
+
+ def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
+ """
+ One-time parsing: Detects and parses tool calls in the provided text.
+
+ :param text: The complete text to parse.
+ :param tools: List of available tools.
+ :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
+ """
+ if self.bot_token not in text:
+ return StreamingParseResult(normal_text=text, calls=[])
+ try:
+            # findall with the two named groups returns a list of
+            # (tool_call_id, function_arguments) tuples, one per complete
+            # tool call found in the text, so each match below carries
+            # both the id and the raw JSON arguments
+ function_call_tuples = self.tool_call_regex.findall(text)
+
+ logger.debug("function_call_tuples: %s", function_call_tuples)
+
+ tool_calls = []
+ for match in function_call_tuples:
+ function_id, function_args = match
+ m = self.tool_call_id_regex.match(function_id)
+ if not m:
+ logger.warning("Unexpected tool_call_id format: %s", function_id)
+ continue
+ function_name = m.group("name")
+ function_idx = int(m.group("index"))
+
+ logger.info(f"function_name {function_name}")
+
+ tool_calls.append(
+ ToolCallItem(
+ tool_index=function_idx,
+ name=function_name,
+ parameters=function_args,
+ )
+ )
- return []
+ content = text[: text.find(self.bot_token)]
+ return StreamingParseResult(normal_text=content, calls=tool_calls)
+ except Exception as e:
+ logger.error(f"Error in detect_and_parse: {e}")
+ # return the normal text if parsing fails
+ return StreamingParseResult(normal_text=text)
-class MultiFormatParser:
- def __init__(self, detectors: List[BaseFormatDetector]):
+ def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> StreamingParseResult:
"""
- :param detectors: A series of available Detector instances passed in
+        Streaming incremental parsing of tool calls for the KimiK2 format.
"""
- self.detectors = detectors
+ self._buffer += new_text
+ current_text = self._buffer
+
+ # Check if we have a tool call (either the start token or individual tool call)
+ has_tool_call = self.bot_token in current_text or self.tool_call_start_token in current_text
+
+ if not has_tool_call:
+ self._buffer = ""
+ for e_token in [self.eot_token, self.tool_call_end_token]:
+ if e_token in new_text:
+ new_text = new_text.replace(e_token, "")
+ return StreamingParseResult(normal_text=new_text)
+
+ if not hasattr(self, "_tool_indices"):
+ self._tool_indices = self._get_tool_indices(tools)
+
+ calls: list[ToolCallItem] = []
+ try:
+ match = self.stream_tool_call_portion_regex.search(current_text)
+ if match:
+ function_id = match.group("tool_call_id")
+ function_args = match.group("function_arguments")
+
+ m = self.tool_call_id_regex.match(function_id)
+ if not m:
+ logger.warning("Unexpected tool_call_id format: %s", function_id)
+ return StreamingParseResult(normal_text="", calls=calls)
+ function_name = m.group("name")
+
+ # Initialize state if this is the first tool call
+ if self.current_tool_id == -1:
+ self.current_tool_id = 0
+ self.prev_tool_call_arr = []
+ self.streamed_args_for_tool = [""]
+
+ # Ensure we have enough entries in our tracking arrays
+ while len(self.prev_tool_call_arr) <= self.current_tool_id:
+ self.prev_tool_call_arr.append({})
+ while len(self.streamed_args_for_tool) <= self.current_tool_id:
+ self.streamed_args_for_tool.append("")
+
+ if not self.current_tool_name_sent:
+ calls.append(
+ ToolCallItem(
+ tool_index=self.current_tool_id,
+ name=function_name,
+ parameters="",
+ )
+ )
+ self.current_tool_name_sent = True
+ # Store the tool call info for serving layer completions endpoint
+ self.prev_tool_call_arr[self.current_tool_id] = {
+ "name": function_name,
+ "arguments": {},
+ }
+ else:
+ argument_diff = (
+ function_args[len(self._last_arguments) :]
+ if function_args.startswith(self._last_arguments)
+ else function_args
+ )
- def parse_once(self, text: str, tools: List[Function]):
+ parsed_args_diff = argument_diff.split("<|tool_call_end|>", 1)[0]
+
+ if parsed_args_diff:
+
+ calls.append(
+ ToolCallItem(
+ tool_index=self.current_tool_id,
+ name=None,
+ parameters=parsed_args_diff,
+ )
+ )
+ self._last_arguments += argument_diff
+ self.streamed_args_for_tool[self.current_tool_id] += parsed_args_diff
+
+ parsed_args = function_args.split("<|tool_call_end|>", 1)[0]
+ if _is_complete_json(parsed_args):
+ try:
+ parsed_args = json.loads(parsed_args)
+ self.prev_tool_call_arr[self.current_tool_id]["arguments"] = parsed_args
+ except json.JSONDecodeError:
+ pass
+
+ # Find the end of the current tool call and remove only that part from buffer
+ tool_call_end_pattern = r"<\|tool_call_begin\|>.*?<\|tool_call_end\|>"
+ match = re.search(tool_call_end_pattern, current_text, re.DOTALL)
+ if match:
+ # Remove the completed tool call from buffer, keep any remaining content
+ self._buffer = current_text[match.end() :]
+ else:
+ self._buffer = ""
+
+ result = StreamingParseResult(normal_text="", calls=calls)
+ self.current_tool_id += 1
+ self._last_arguments = ""
+ self.current_tool_name_sent = False
+ return result
+
+ return StreamingParseResult(normal_text="", calls=calls)
+
+ except Exception as e:
+ logger.error(f"Error in parse_streaming_increment: {e}")
+ return StreamingParseResult(normal_text=current_text)
+
+
+class DeepSeekV31Detector(BaseFormatDetector):
+ """
+    Detector for DeepSeek V3.1 model function call format.
+
+    The DeepSeek V3.1 format uses special Unicode tokens to delimit function calls,
+    with raw JSON for the arguments.
+
+ Format Structure:
+ ```
+    <｜tool▁calls▁begin｜><｜tool▁call▁begin｜>{function_name}<｜tool▁sep｜>{json_arguments}
+    <｜tool▁call▁end｜><｜tool▁calls▁end｜><｜end▁of▁sentence｜>
+ ```
+ Examples:
+ ```
+    <｜tool▁calls▁begin｜><｜tool▁call▁begin｜>get_current_weather<｜tool▁sep｜>
+    {"location": "Tokyo"}<｜tool▁call▁end｜><｜tool▁call▁begin｜>get_current_weather
+    <｜tool▁sep｜>{"location": "Paris"}<｜tool▁call▁end｜><｜tool▁calls▁end｜><｜end▁of▁sentence｜>
+ ```
+
+ Key Components:
+    - Tool Calls Section: Wrapped between `<｜tool▁calls▁begin｜>` and `<｜tool▁calls▁end｜>`
+    - Individual Tool Call: Wrapped between `<｜tool▁call▁begin｜>` and `<｜tool▁call▁end｜>`
+    - Function Declaration: `<｜tool▁call▁begin｜>{function_name}<｜tool▁sep｜>`
+    - Arguments: raw JSON between `<｜tool▁sep｜>` and `<｜tool▁call▁end｜>`
+ - Supports multiple tool calls
+
+ Reference: https://www.modelscope.cn/models/deepseek-ai/DeepSeek-V3.1
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.bot_token = "<|tool▁calls▁begin|>"
+ self.eot_token = "<|tool▁calls▁end|>"
+ self.func_call_regex = r"<|tool▁call▁begin|>.*?<|tool▁call▁end|>"
+ self.func_detail_regex = r"<|tool▁call▁begin|>(.*)<|tool▁sep|>(.*)<|tool▁call▁end|>"
+ self._last_arguments = ""
+ self.current_tool_id = -1
+
+ def has_tool_call(self, text: str) -> bool:
+ """Check if the text contains a deepseek format tool call."""
+ return self.bot_token in text
+
+ def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
"""
- One-time parsing: Loop through detectors until there are no new matches or text is exhausted
- Return: (final_text, all_calls)
- - final_text: The remaining text after parsing that was not
- consumed by any Detector (can be treated as normal text)
- - all_calls: All calls parsed by the Detectors
+ One-time parsing: Detects and parses tool calls in the provided text.
+
+ :param text: The complete text to parse.
+ :param tools: List of available tools.
+ :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
"""
- final_calls = []
- final_normal_text = text
- for detector in self.detectors:
- tool_call_list = detector.detect_and_parse(text, tools)
- if len(tool_call_list) > 0: # parsed successfully
- final_calls = tool_call_list
- break
-
- # leftover_text is the normal text not consumed by any Detector
- return final_normal_text, final_calls
+ idx = text.find(self.bot_token)
+ normal_text = text[:idx].strip() if idx != -1 else text
+ if self.bot_token not in text:
+ return StreamingParseResult(normal_text=normal_text, calls=[])
+ match_result_list = re.findall(self.func_call_regex, text, re.DOTALL)
+ calls = []
+ try:
+ for match_result in match_result_list:
+ # Get function name
+ func_detail = re.search(self.func_detail_regex, match_result, re.DOTALL)
+ func_name = func_detail.group(1)
+ func_args = func_detail.group(2)
+ func_args = json.loads(func_args)
+ # construct match_result for parse_base_json
+ match_result = {"name": func_name, "parameters": func_args}
+ calls.extend(self.parse_base_json(match_result, tools))
+ return StreamingParseResult(normal_text=normal_text, calls=calls)
+ except Exception as e:
+ logger.error(f"Error in detect_and_parse: {e}")
+ # return the normal text if parsing fails
+ return StreamingParseResult(normal_text=text)
- def parse_streaming_increment(self, new_text: str, tools: List[Function]):
+ def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> StreamingParseResult:
"""
- Streaming incremental parsing: Feed new_text to each detector's parse_streaming_increment
- and merge their produced normal_text/calls to return.
- (The logic here can be "priority-based" or "parallel parsing" based on your needs)
+        Streaming incremental parsing of tool calls for the DeepSeek V3.1 format.
"""
- final_normal_text = ""
- final_calls = []
+ self._buffer += new_text
+ current_text = self._buffer
- for detector in self.detectors:
- sp_result = detector.parse_streaming_increment(new_text, tools)
- # Merge normal_text and calls
- # If one sp_result contains result call, this should be a successful parse
- # If one sp_result only contains normal_text, this can either be a successful
- # parse or it is not using the desired parsing tool.
- if sp_result.normal_text:
- final_normal_text = sp_result.normal_text
- if sp_result.calls:
- final_calls.extend(sp_result.calls)
- final_normal_text = sp_result.normal_text
- break
+ # Check if we have a tool call (either the start token or individual tool call)
+        has_tool_call = self.bot_token in current_text or "<｜tool▁call▁begin｜>" in current_text
- return final_normal_text, final_calls
+ if not has_tool_call:
+ self._buffer = ""
+            for e_token in [self.eot_token, "<｜tool▁call▁end｜>"]:
+ if e_token in new_text:
+ new_text = new_text.replace(e_token, "")
+ return StreamingParseResult(normal_text=new_text)
+
+ if not hasattr(self, "_tool_indices"):
+ self._tool_indices = self._get_tool_indices(tools)
+
+ calls: list[ToolCallItem] = []
+ try:
+ partial_match = re.search(
+ pattern=r"<|tool▁call▁begin|>(.*)<|tool▁sep|>(.*)<|tool▁call▁end|>",
+ string=current_text,
+ flags=re.DOTALL,
+ )
+ if partial_match:
+ func_name = partial_match.group(1).strip()
+ func_args_raw = partial_match.group(2).strip()
+
+ # Initialize state if this is the first tool call
+ if self.current_tool_id == -1:
+ self.current_tool_id = 0
+ self.prev_tool_call_arr = []
+ self.streamed_args_for_tool = [""]
+
+ # Ensure we have enough entries in our tracking arrays
+ while len(self.prev_tool_call_arr) <= self.current_tool_id:
+ self.prev_tool_call_arr.append({})
+ while len(self.streamed_args_for_tool) <= self.current_tool_id:
+ self.streamed_args_for_tool.append("")
+
+ if not self.current_tool_name_sent:
+ calls.append(
+ ToolCallItem(
+ tool_index=self.current_tool_id,
+ name=func_name,
+ parameters="",
+ )
+ )
+ self.current_tool_name_sent = True
+ # Store the tool call info for serving layer completions endpoint
+ self.prev_tool_call_arr[self.current_tool_id] = {
+ "name": func_name,
+ "arguments": {},
+ }
+ else:
+ argument_diff = (
+ func_args_raw[len(self._last_arguments) :]
+ if func_args_raw.startswith(self._last_arguments)
+ else func_args_raw
+ )
+
+ if argument_diff:
+ calls.append(
+ ToolCallItem(
+ tool_index=self.current_tool_id,
+ name=None,
+ parameters=argument_diff,
+ )
+ )
+ self._last_arguments += argument_diff
+ self.streamed_args_for_tool[self.current_tool_id] += argument_diff
+
+ if _is_complete_json(func_args_raw):
+ # Update the stored arguments
+ try:
+ parsed_args = json.loads(func_args_raw)
+ self.prev_tool_call_arr[self.current_tool_id]["arguments"] = parsed_args
+ except json.JSONDecodeError:
+ pass
+
+ # Find the end of the current tool call and remove only that part from buffer
+                tool_call_end_pattern = r"<｜tool▁call▁begin｜>.*?<｜tool▁call▁end｜>"
+ match = re.search(tool_call_end_pattern, current_text, re.DOTALL)
+ if match:
+ # Remove the completed tool call from buffer, keep any remaining content
+ self._buffer = current_text[match.end() :]
+ else:
+ self._buffer = ""
+
+ result = StreamingParseResult(normal_text="", calls=calls)
+ self.current_tool_id += 1
+ self._last_arguments = ""
+ self.current_tool_name_sent = False
+ return result
+
+ return StreamingParseResult(normal_text="", calls=calls)
+
+ except Exception as e:
+ logger.error(f"Error in parse_streaming_increment: {e}")
+ return StreamingParseResult(normal_text=current_text)
+
+
+class DeepSeekV3Detector(BaseFormatDetector):
+ """
+ Detector for DeepSeek V3 model function call format.
+
+ The DeepSeek V3 format uses special Unicode tokens to delimit function calls
+ with JSON code blocks for arguments.
+
+ Format Structure:
+ ```
+    <｜tool▁calls▁begin｜><｜tool▁call▁begin｜>function<｜tool▁sep｜>{function_name}
+    \n```json\n{json_arguments}\n```<｜tool▁call▁end｜><｜tool▁calls▁end｜><｜end▁of▁sentence｜>
+ ```
+ Examples:
+ ```
+    <｜tool▁calls▁begin｜><｜tool▁call▁begin｜>function<｜tool▁sep｜>get_current_weather\n
+    ```json\n{"location": "Tokyo"}\n```<｜tool▁call▁end｜>\n<｜tool▁call▁begin｜>
+    function<｜tool▁sep｜>get_current_weather\n```json\n{"location": "Paris"}\n```
+    <｜tool▁call▁end｜><｜tool▁calls▁end｜><｜end▁of▁sentence｜>
+ ```
+
+ Key Components:
+    - Tool Calls Section: Wrapped between `<｜tool▁calls▁begin｜>` and `<｜tool▁calls▁end｜>`
+    - Individual Tool Call: Wrapped between `<｜tool▁call▁begin｜>` and `<｜tool▁call▁end｜>`
+    - Function Declaration: `function<｜tool▁sep｜>{function_name}`
+ - Arguments: JSON code block between ````json` and ````
+ - Supports multiple tool calls
+
+ Reference: https://huggingface.co/deepseek-ai/DeepSeek-V3-0324?chat_template=default
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.bot_token = "<|tool▁calls▁begin|>"
+ self.eot_token = "<|tool▁calls▁end|>"
+ self.func_call_regex = r"<|tool▁call▁begin|>.*?<|tool▁call▁end|>"
+ self.func_detail_regex = r"<|tool▁call▁begin|>(.*)<|tool▁sep|>(.*)\n```json\n(.*)\n```<|tool▁call▁end|>"
+ self._last_arguments = ""
+ self.current_tool_id = -1
+
+ def has_tool_call(self, text: str) -> bool:
+ """Check if the text contains a deepseek format tool call."""
+ return self.bot_token in text
+
+ def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
+ """
+ One-time parsing: Detects and parses tool calls in the provided text.
+
+ :param text: The complete text to parse.
+ :param tools: List of available tools.
+ :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
+ """
+ idx = text.find(self.bot_token)
+ normal_text = text[:idx].strip() if idx != -1 else text
+ if self.bot_token not in text:
+ return StreamingParseResult(normal_text=normal_text, calls=[])
+ match_result_list = re.findall(self.func_call_regex, text, re.DOTALL)
+ calls = []
+ try:
+ for match_result in match_result_list:
+ # Get function name
+ func_detail = re.search(self.func_detail_regex, match_result, re.DOTALL)
+ func_name = func_detail.group(2)
+ func_args = func_detail.group(3)
+ func_args = json.loads(func_args)
+ # construct match_result for parse_base_json
+ match_result = {"name": func_name, "parameters": func_args}
+ calls.extend(self.parse_base_json(match_result, tools))
+ return StreamingParseResult(normal_text=normal_text, calls=calls)
+ except Exception as e:
+ logger.error(f"Error in detect_and_parse: {e}")
+ # return the normal text if parsing fails
+ return StreamingParseResult(normal_text=text)
+
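+ # Illustrative one-shot parse (editorial sketch, hedged: `tools` is an
+ # assumed list of Tool definitions, and the ToolCallItem fields are shown
+ # for illustration only):
+ #
+ # out = ("<｜tool▁calls▁begin｜><｜tool▁call▁begin｜>function<｜tool▁sep｜>"
+ # 'get_current_weather\n```json\n{"location": "Tokyo"}\n```'
+ # "<｜tool▁call▁end｜><｜tool▁calls▁end｜>")
+ # result = DeepSeekV3Detector().detect_and_parse(out, tools)
+ # # result.normal_text == "" and result.calls holds one ToolCallItem whose
+ # # name is "get_current_weather" and whose parameters are the JSON-encoded
+ # # arguments.
+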
+ def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> StreamingParseResult:
+ """
+ Streaming incremental parsing tool calls for DeepSeekV3 format.
+ """
+ self._buffer += new_text
+ current_text = self._buffer
+
+ # Check if we have a tool call (either the start token or individual tool call)
+ has_tool_call = self.bot_token in current_text or "<｜tool▁call▁begin｜>" in current_text
+
+ if not has_tool_call:
+ self._buffer = ""
+ for e_token in [self.eot_token, "```", "<｜tool▁call▁end｜>"]:
+ if e_token in new_text:
+ new_text = new_text.replace(e_token, "")
+ return StreamingParseResult(normal_text=new_text)
+
+ if not hasattr(self, "_tool_indices"):
+ self._tool_indices = self._get_tool_indices(tools)
+
+ calls: List[ToolCallItem] = []
+ try:
+ partial_match = re.search(
+ pattern=r"<｜tool▁call▁begin｜>(.*)<｜tool▁sep｜>(.*)\n```json\n(.*)\n```.*",
+ string=current_text,
+ flags=re.DOTALL,
+ )
+ if partial_match:
+ func_name = partial_match.group(2).strip()
+ func_args_raw = partial_match.group(3).strip()
+
+ # Initialize state if this is the first tool call
+ if self.current_tool_id == -1:
+ self.current_tool_id = 0
+ self.prev_tool_call_arr = []
+ self.streamed_args_for_tool = [""]
+
+ # Ensure we have enough entries in our tracking arrays
+ while len(self.prev_tool_call_arr) <= self.current_tool_id:
+ self.prev_tool_call_arr.append({})
+ while len(self.streamed_args_for_tool) <= self.current_tool_id:
+ self.streamed_args_for_tool.append("")
+
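+ # Two-phase streaming protocol: the tool name is emitted once (with
+ # empty parameters); subsequent increments emit only the incremental
+ # argument suffix.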
+ if not self.current_tool_name_sent:
+ calls.append(
+ ToolCallItem(
+ tool_index=self.current_tool_id,
+ name=func_name,
+ parameters="",
+ )
+ )
+ self.current_tool_name_sent = True
+ # Store the tool call info for serving layer completions endpoint
+ self.prev_tool_call_arr[self.current_tool_id] = {
+ "name": func_name,
+ "arguments": {},
+ }
+ else:
+ argument_diff = (
+ func_args_raw[len(self._last_arguments) :]
+ if func_args_raw.startswith(self._last_arguments)
+ else func_args_raw
+ )
+
+ if argument_diff:
+ calls.append(
+ ToolCallItem(
+ tool_index=self.current_tool_id,
+ name=None,
+ parameters=argument_diff,
+ )
+ )
+ self._last_arguments += argument_diff
+ self.streamed_args_for_tool[self.current_tool_id] += argument_diff
+
+ if _is_complete_json(func_args_raw):
+ # Update the stored arguments
+ try:
+ parsed_args = json.loads(func_args_raw)
+ self.prev_tool_call_arr[self.current_tool_id]["arguments"] = parsed_args
+ except json.JSONDecodeError:
+ pass
+
+ # Find the end of the current tool call and remove only that part from buffer
+ tool_call_end_pattern = r"<｜tool▁call▁begin｜>.*?<｜tool▁call▁end｜>"
+ match = re.search(tool_call_end_pattern, current_text, re.DOTALL)
+ if match:
+ # Remove the completed tool call from buffer, keep any remaining content
+ self._buffer = current_text[match.end() :]
+ else:
+ self._buffer = ""
+
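+ # Reset per-call state so a following <｜tool▁call▁begin｜> block in
+ # the same stream is parsed as a fresh tool call.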
+ result = StreamingParseResult(normal_text="", calls=calls)
+ self.current_tool_id += 1
+ self._last_arguments = ""
+ self.current_tool_name_sent = False
+ return result
+
+ return StreamingParseResult(normal_text="", calls=calls)
+
+ except Exception as e:
+ logger.error(f"Error in parse_streaming_increment: {e}")
+ return StreamingParseResult(normal_text=current_text)
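+
+ # Illustrative streaming behaviour (editorial sketch, hedged: the chunk
+ # boundaries below are assumptions). Once a buffered increment first
+ # completes the ```json ... ``` block, the detector emits the tool name
+ # with empty parameters; a later increment carries the argument diff.
+ #
+ # detector = DeepSeekV3Detector()
+ # r1 = detector.parse_streaming_increment(text_through_closing_fence, tools)
+ # # r1.calls -> [ToolCallItem(tool_index=0, name="get_current_weather", parameters="")]
+ # r2 = detector.parse_streaming_increment("<｜tool▁call▁end｜>", tools)
+ # # r2.calls -> [ToolCallItem(tool_index=0, name=None, parameters='{"location": "Tokyo"}')]
+
+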
class FunctionCallParser:
"""
- In streaming scenarios, each time new_text is received, it calls multi_format_parser.parse_streaming_increment
+ Parser for function/tool calls in model outputs.
+
+ This class handles both streaming and non-streaming parsing of function calls using a detector.
+ In streaming scenarios, each time new_text is received, it calls detector.parse_streaming_increment
and returns the resulting normal_text and calls to the upper layer (or SSE).
"""
- ToolCallParserEnum: Dict[str, BaseFormatDetector] = {
+ ToolCallParserEnum: Dict[str, Type[BaseFormatDetector]] = {
+ "deepseekv3": DeepSeekV3Detector,
+ "deepseekv31": DeepSeekV31Detector,
+ "kimi_k2": KimiK2Detector,
"llama3": Llama32Detector,
- "qwen25": Qwen25Detector,
"mistral": MistralDetector,
+ "qwen": Qwen25Detector,
+ "qwen25": Qwen25Detector,
}
- def __init__(self, tools: List[Function], tool_call_parser: str = None):
- detectors = []
- if tool_call_parser:
- detector_class = self.ToolCallParserEnum.get(tool_call_parser)
- if detector_class:
- detectors.append(detector_class())
- else:
- raise ValueError(f"Unsupported tool_call_parser: {tool_call_parser}")
+ def __init__(self, tools: List[Tool], tool_call_parser: str):
+ detector: BaseFormatDetector
+ detector_class = self.ToolCallParserEnum.get(tool_call_parser)
+ if detector_class:
+ detector = detector_class()
else:
- raise ValueError("Tool Call Parser Not Given!")
+ raise ValueError(f"Unsupported tool_call_parser: {tool_call_parser}")
- self.multi_format_parser = MultiFormatParser(detectors)
+ self.detector = detector
self.tools = tools
- def parse_non_stream(self, full_text: str):
+ def has_tool_call(self, text: str) -> bool:
"""
- Non-streaming call: one-time parsing
+ Check if the given text contains a tool call in the format supported by this parser.
+ This delegates to the detector's implementation.
+
+ Args:
+ text: The text to check for tool calls
+
+ Returns:
+ True if the text contains a tool call, False otherwise
"""
- full_normal_text, calls = self.multi_format_parser.parse_once(full_text, self.tools)
- return full_normal_text, calls
+ if not self.tools:
+ return False
+ return self.detector.has_tool_call(text)
- def parse_stream_chunk(self, chunk_text: str):
+ def parse_non_stream(self, full_text: str) -> Tuple[str, List[ToolCallItem]]:
"""
- Streaming call: incremental parsing
+ One-time parsing of the full text to extract tool calls.
+
+ Args:
+ full_text: The complete text to parse
+
+ Returns:
+ A tuple containing:
+ - The remaining text after parsing that was not consumed by the detector (can be treated as normal text)
+ - A list of tool calls parsed from the text
+ """
+ if not self.tools:
+ return full_text, []
+ parsed_result = self.detector.detect_and_parse(full_text, self.tools)
+ tool_call_list = parsed_result.calls
+ if tool_call_list:
+ return parsed_result.normal_text, tool_call_list
+ else:
+ return full_text, []
+
+ def parse_stream_chunk(self, chunk_text: str) -> Tuple[str, List[ToolCallItem]]:
"""
- normal_text, calls = self.multi_format_parser.parse_streaming_increment(chunk_text, self.tools)
- return normal_text, calls
+ Streaming incremental parsing of chunks of text as they arrive.
+
+ Args:
+ chunk_text: The new chunk of text to parse
+
+ Returns:
+ A tuple containing:
+ - The normal text that should be displayed to the user
+ - A list of tool calls parsed from the chunk
+ """
+ if not self.tools:
+ return chunk_text, []
+ final_normal_text = ""
+ final_calls = []
+
+ sp_result = self.detector.parse_streaming_increment(chunk_text, self.tools)
+ if sp_result.normal_text:
+ final_normal_text = sp_result.normal_text
+ if sp_result.calls:
+ final_calls.extend(sp_result.calls)
+ final_normal_text = sp_result.normal_text
+
+ return final_normal_text, final_calls
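+
+
+ # Illustrative usage (editorial sketch, hedged: `tools`, `full_text`, and
+ # `chunk_text` are assumed inputs from the serving layer):
+ #
+ # parser = FunctionCallParser(tools=tools, tool_call_parser="deepseekv3")
+ # if parser.has_tool_call(full_text):
+ #     normal_text, calls = parser.parse_non_stream(full_text)
+ # # or, per streaming (SSE) chunk:
+ # # normal_text, calls = parser.parse_stream_chunk(chunk_text)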