diff --git a/docs/features/reasoning_output.md b/docs/features/reasoning_output.md
index a8d6fe44d69..0279336dbd6 100644
--- a/docs/features/reasoning_output.md
+++ b/docs/features/reasoning_output.md
@@ -11,6 +11,9 @@ Reasoning models return an additional `reasoning_content` field in their output,
 | baidu/ERNIE-4.5-VL-28B-A3B-Paddle | ernie-45-vl | ✅ | ❌ |"chat_template_kwargs":{"enable_thinking": true/false}|
 | baidu/ERNIE-4.5-21B-A3B-Thinking | ernie-x1 | ✅ Not supported for turning off | ✅|❌|
 | baidu/ERNIE-4.5-VL-28B-A3B-Thinking | ernie-45-vl-thinking | ✅ Not recommended to turn off | ✅|"chat_template_kwargs": {"options": {"thinking_mode": "open/close"}}|
+| unsloth/DeepSeek-V3.1-BF16 | deepseek | ❌ (thinking mode off by default) | ✅|❌|
+| unsloth/DeepSeek-V3-0324-BF16 | deepseek | ✅ (thinking mode on by default) | ✅|❌|
+| unsloth/DeepSeek-R1-BF16 | deepseek | ✅ (thinking mode on by default) | ✅|❌|

 The reasoning model requires a specified parser to extract reasoning content. Referring to the `thinking switch parameters` of each model can turn off the model's thinking mode.

@@ -159,3 +162,31 @@ Model output example
 }
 ```
 More reference documentation related to tool calling usage: [Tool Calling](./tool_calling.md)
+
+## Error Handling and Invalid Format
+
+The DeepSeek reasoning parser handles invalid or incomplete output formats gracefully:
+
+### Missing Start Tag
+If the model output contains only the end tag without the start tag:
+- **Input**: `abc</think>xyz`
+- **Output**: `reasoning_content="abc"`, `content="xyz"`
+- The parser extracts the content before the end tag as reasoning and the content after it as the reply.
+
+### Missing End Tag
+If the model output contains only the start tag without the end tag:
+- **Input**: `<think>abc`
+- **Output**: `reasoning_content="abc"`, `content=None`
+- The parser treats all content as reasoning when the end tag is missing.
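The fallback rules above can be sketched in a few lines of plain Python. This is a simplified, illustrative stand-in (the actual `DeepSeekReasoningParser` also works on token IDs and handles streaming), assuming DeepSeek's standard `<think>`/`</think>` tags:

```python
THINK_START, THINK_END = "<think>", "</think>"

def split_reasoning(output: str):
    """Return (reasoning_content, content) following the fallback rules above."""
    if THINK_START in output:
        output = output.split(THINK_START, 1)[1]
        if THINK_END not in output:
            # Missing end tag: everything is reasoning.
            return output, None
        reasoning, _, content = output.partition(THINK_END)
        return reasoning, content.strip() or None
    if THINK_END in output:
        # Missing start tag: split at the end tag.
        reasoning, _, content = output.partition(THINK_END)
        return reasoning, content.strip() or None
    # No tags at all (thinking mode off): everything is reply content.
    return None, output

print(split_reasoning("abc</think>xyz"))   # ('abc', 'xyz')
print(split_reasoning("<think>abc"))       # ('abc', None)
print(split_reasoning("direct response"))  # (None, 'direct response')
```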
+
+### No Reasoning Tags (Thinking Mode Off)
+When thinking mode is disabled (e.g., DeepSeek-V3.1 by default):
+- **Input**: `direct response`
+- **Output**: `reasoning_content=None`, `content="direct response"`
+- The parser treats the entire output as reply content.
+
+### Protocol Violation with Tool Calls
+If there is non-whitespace content between the reasoning end tag (`</think>`) and the tool calls:
+- **Input**: `<think>thinking</think>\n\nABC\n<|tool▁calls▁begin|>...`
+- **Output**: Tool calls are not parsed; the entire output is returned as `content`.
+- This ensures tool calls are only parsed when they immediately follow the reasoning content.
diff --git a/docs/features/tool_calling.md b/docs/features/tool_calling.md
index bee80d64474..91748e54142 100644
--- a/docs/features/tool_calling.md
+++ b/docs/features/tool_calling.md
@@ -8,6 +8,13 @@ This document describes how to configure the server in FastDeploy to use the too
 | baidu/ERNIE-4.5-21B-A3B-Thinking | ernie-x1 |
 | baidu/ERNIE-4.5-VL-28B-A3B-Thinking | ernie-45-vl-thinking |
+
+## Tool Call Parser for DeepSeek Series Models
+| Model Name | Parser Name |
+|---------------|-------------|
+| unsloth/DeepSeek-V3.1-BF16 | deepseek |
+| unsloth/DeepSeek-V3-0324-BF16 | deepseek |
+| unsloth/DeepSeek-R1-BF16 | deepseek |
+
 ## Quickstart

 ### Starting FastDeploy with Tool Calling Enabled.

@@ -90,6 +97,27 @@ The example output is as follows. It shows that the model's output of the though
 }
 ```
+
+## Error Handling and Invalid Format
+
+The DeepSeek tool parser handles various invalid or incomplete format scenarios:
+
+### Protocol Violation
+If there is non-whitespace content between the reasoning end tag (`</think>`) and the tool calls:
+- **Input**: `<think>thinking</think>\n\nABC\n<|tool▁calls▁begin|>...`
+- **Output**: `tools_called=False`, `tool_calls=None`, `content=<entire output>`
+- Tool calls are not parsed when the protocol is violated; the entire output is returned as content.
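The whitespace-only rule described above can be illustrated with a short sketch. The names below are hypothetical, not the FastDeploy internals; the real parser applies the same check on the full model output before matching any tool calls:

```python
THINK_END = "</think>"
TOOL_CALLS_BEGIN = "<|tool▁calls▁begin|>"

def tool_calls_allowed(output: str) -> bool:
    """Allow tool-call parsing only when at most whitespace separates
    the reasoning end tag from the tool-calls block."""
    if TOOL_CALLS_BEGIN not in output:
        return False  # nothing to parse
    if THINK_END in output:
        after = output.split(THINK_END, 1)[1]
        between = after.split(TOOL_CALLS_BEGIN, 1)[0]
        return between.strip() == ""
    return True  # no reasoning tag present, so no protocol to violate

print(tool_calls_allowed("<think>t</think>\n\n<|tool▁calls▁begin|>..."))   # True
print(tool_calls_allowed("<think>t</think>\nABC\n<|tool▁calls▁begin|>...")) # False
```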
+ +### Malformed JSON Arguments +If the tool call arguments contain invalid JSON: +- **Input**: `<|tool▁call▁begin|>get_weather<|tool▁sep|>{"location": "北京", "unit":}<|tool▁call▁end|>` +- **Output**: The parser attempts to use `partial_json_parser` to recover valid JSON. If recovery fails, it returns an empty object `{}` or the raw text. +- This ensures graceful handling of incomplete JSON during streaming. + +### Missing Tool Call End Tag +If a tool call is incomplete (missing end tag): +- **Input**: `<|tool▁call▁begin|>get_weather<|tool▁sep|>{"location": "北京"` +- **Output**: In streaming mode, the parser waits for more data. In non-streaming mode, it attempts to extract what's available. + ## Parallel Tool Calls If the model can generate parallel tool calls, FastDeploy will return a list: ```bash diff --git a/docs/zh/features/reasoning_output.md b/docs/zh/features/reasoning_output.md index c43f9fb4edd..150de17b31c 100644 --- a/docs/zh/features/reasoning_output.md +++ b/docs/zh/features/reasoning_output.md @@ -11,6 +11,9 @@ | baidu/ERNIE-4.5-VL-28B-A3B-Paddle | ernie-45-vl | ✅ | ❌ |"chat_template_kwargs":{"enable_thinking": true/false}| | baidu/ERNIE-4.5-21B-A3B-Thinking | ernie-x1 | ✅不支持关思考 | ✅|❌| | baidu/ERNIE-4.5-VL-28B-A3B-Thinking | ernie-45-vl-thinking | ✅不推荐关闭 | ✅|"chat_template_kwargs": {"options": {"thinking_mode": "open/close"}}| +| unsloth/DeepSeek-V3.1-BF16 | deepseek | ❌ (默认关闭思考模式) | ✅|❌| +| unsloth/DeepSeek-V3-0324-BF16 | deepseek | ✅ (默认开启思考模式) | ✅|❌| +| unsloth/DeepSeek-R1-BF16 | deepseek | ✅ (默认开启思考模式) | ✅|❌| 思考模型需要指定解析器,以便于对思考内容进行解析. 参考各个模型的 `思考开关控制参数` 可以关闭模型思考模式. 
@@ -158,3 +161,31 @@ curl -X POST "http://0.0.0.0:8390/v1/chat/completions" \
 }
 ```
 更多工具调用相关的使用参考文档 [Tool Calling](./tool_calling.md)
+
+## 错误处理和格式不合法情况
+
+DeepSeek 推理解析器能够优雅地处理各种格式不合法或不完整的情况:
+
+### 缺少起始标签
+如果模型输出只包含结束标签而没有起始标签:
+- **输入**: `abc</think>xyz`
+- **输出**: `reasoning_content="abc"`, `content="xyz"`
+- 解析器会将结束标签之前的内容提取为思考内容,之后的内容提取为回复内容。
+
+### 缺少结束标签
+如果模型输出只包含起始标签而没有结束标签:
+- **输入**: `<think>abc`
+- **输出**: `reasoning_content="abc"`, `content=None`
+- 解析器会将所有内容视为思考内容。
+
+### 无思考标签(思考模式关闭)
+当思考模式被关闭时(例如 DeepSeek-V3.1 默认关闭):
+- **输入**: `direct response`
+- **输出**: `reasoning_content=None`, `content="direct response"`
+- 解析器会将整个输出视为回复内容。
+
+### 协议不规范(工具调用前有非空白字符)
+如果思考结束标签(`</think>`)和工具调用之间存在非空白字符:
+- **输入**: `<think>thinking</think>\n\nABC\n<|tool▁calls▁begin|>...`
+- **输出**: 工具调用不会被解析,整个输出作为 `content` 返回
+- 这确保了只有在工具调用紧跟在思考内容之后时才会被解析。
diff --git a/docs/zh/features/tool_calling.md b/docs/zh/features/tool_calling.md
index dbc99fa7308..c165c8ddc7c 100644
--- a/docs/zh/features/tool_calling.md
+++ b/docs/zh/features/tool_calling.md
@@ -8,6 +8,13 @@
 | baidu/ERNIE-4.5-21B-A3B-Thinking | ernie-x1 |
 | baidu/ERNIE-4.5-VL-28B-A3B-Thinking | ernie-45-vl-thinking |
+
+## DeepSeek系列模型配套工具解析器
+| 模型名称 | 解析器名称 |
+|---------------|-------------|
+| unsloth/DeepSeek-V3.1-BF16 | deepseek |
+| unsloth/DeepSeek-V3-0324-BF16 | deepseek |
+| unsloth/DeepSeek-R1-BF16 | deepseek |
+
 ## 快速开始

 ### 启动包含解析器的FastDeploy

@@ -92,6 +99,27 @@ curl -X POST http://0.0.0.0:8000/v1/chat/completions \
 ]
 }
 ```
+## 错误处理和格式不合法情况
+
+DeepSeek 工具解析器能够处理各种格式不合法或不完整的情况:
+
+### 协议不规范
+如果思考结束标签(`</think>`)和工具调用之间存在非空白字符:
+- **输入**: `<think>thinking</think>\n\nABC\n<|tool▁calls▁begin|>...`
+- **输出**: `tools_called=False`, `tool_calls=None`, `content=<完整输出>`
+- 当协议不规范时,工具调用不会被解析,整个输出作为 content 返回。
+
+### JSON 参数格式错误
+如果工具调用的参数包含无效的 JSON:
+- **输入**: `<|tool▁call▁begin|>get_weather<|tool▁sep|>{"location": "北京", "unit":}<|tool▁call▁end|>`
+- **输出**: 解析器会尝试使用 `partial_json_parser` 来恢复有效的 JSON。如果恢复失败,会返回空对象 `{}` 或原始文本。
+- 这确保了在流式输出过程中能够优雅地处理不完整的 JSON。
+
+### 缺少工具调用结束标签
+如果工具调用不完整(缺少结束标签): +- **输入**: `<|tool▁call▁begin|>get_weather<|tool▁sep|>{"location": "北京"` +- **输出**: 在流式模式下,解析器会等待更多数据。在非流式模式下,会尝试提取可用的内容。 + ## 并行工具调用 如果模型能够生成多个并行的工具调用,FastDeploy 会返回一个列表: diff --git a/fastdeploy/entrypoints/openai/tool_parsers/deepseek_tool_parser.py b/fastdeploy/entrypoints/openai/tool_parsers/deepseek_tool_parser.py new file mode 100644 index 00000000000..0744b3f5c0f --- /dev/null +++ b/fastdeploy/entrypoints/openai/tool_parsers/deepseek_tool_parser.py @@ -0,0 +1,357 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + +import json +import re +import uuid +from collections.abc import Sequence +from typing import Optional, Union + +import partial_json_parser +from partial_json_parser.core.options import Allow + +from fastdeploy.entrypoints.openai.protocol import ( + ChatCompletionRequest, + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, + ToolCall, +) +from fastdeploy.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, + ToolParserManager, +) +from fastdeploy.utils import data_processor_logger + + +def random_tool_call_id() -> str: + """Generate a random tool call ID""" + return f"chatcmpl-tool-{str(uuid.uuid4().hex)}" + + +@ToolParserManager.register_module(["deepseek", "deepseek-r1", "deepseek-v3.1", "deepseek-v3-0324"]) +class DeepSeekToolParser(ToolParser): + """ + DeepSeek 系列模型的工具调用解析器(支持 V3.1、V3-0324、R1 三种模型) + + 支持的格式: + - V3.1: <|tool▁call▁begin|>function_name<|tool▁sep|>{"arg": "value"}<|tool▁call▁end|> + - V3-0324/R1: <|tool▁call▁begin|>function<|tool▁sep|>function_name\n```json\n{"arg": "value"}\n```<|tool▁call▁end|> + """ + + def __init__(self, tokenizer, model_name=None): + super().__init__(tokenizer) + + self.model_name = model_name or "" + self.buffer: str = "" + + # 特殊标记 + self.tool_calls_begin_token = "<|tool▁calls▁begin|>" + self.tool_calls_end_token = "<|tool▁calls▁end|>" + self.tool_call_begin_token = "<|tool▁call▁begin|>" + self.tool_call_end_token = "<|tool▁call▁end|>" + self.tool_sep_token = "<|tool▁sep|>" + + # 获取 token IDs + self.tool_calls_begin_token_id = self.vocab.get(self.tool_calls_begin_token) + self.tool_calls_end_token_id = self.vocab.get(self.tool_calls_end_token) + self.tool_call_begin_token_id = self.vocab.get(self.tool_call_begin_token) + self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token) + self.tool_sep_token_id = self.vocab.get(self.tool_sep_token) + + if self.tool_calls_begin_token_id is None or self.tool_call_begin_token_id is None: + 
raise RuntimeError(
+                "DeepSeek Tool parser could not locate tool call tokens in the tokenizer!"
+            )
+
+        # 检测模型版本
+        self.is_v31 = self._detect_model_version()
+
+        if not self.model_tokenizer:
+            raise ValueError(
+                "The model tokenizer must be passed to the ToolParser constructor during construction."
+            )
+
+    def _detect_model_version(self) -> bool:
+        """检测模型版本:V3.1 还是 V3-0324/R1"""
+        if "v3.1" in self.model_name.lower():
+            return True
+        elif "v3-0324" in self.model_name.lower() or "r1" in self.model_name.lower():
+            return False
+        # 默认使用 V3.1 格式
+        return True
+
+    def detect_output_stage(self, prompt_token_ids: Sequence[int]) -> str:
+        """
+        根据进入模型的 prompt_token_ids,判断接下来模型输出是否处于工具调用阶段
+        """
+        if self.tool_calls_begin_token_id in prompt_token_ids:
+            return "TOOL_CALL_STAGE"
+        return "CONTENT_STAGE"
+
+    def extract_tool_calls(
+        self, model_output: str, request: ChatCompletionRequest, output_stage: Optional[str] = None
+    ) -> ExtractedToolCallInformation:
+        """
+        从完整的模型输出中提取工具调用(非流式场景)
+        """
+        try:
+            # 检查是否有工具调用标记
+            if self.tool_calls_begin_token not in model_output:
+                return ExtractedToolCallInformation(tools_called=False, tool_calls=None, content=model_output)
+
+            # 检查 </think> 与工具调用之间是否有非空白字符
+            reasoning_end = "</think>"
+            if reasoning_end in model_output:
+                reasoning_end_pos = model_output.find(reasoning_end)
+                after_reasoning = model_output[reasoning_end_pos + len(reasoning_end):]
+                tool_calls_begin_pos = after_reasoning.find(self.tool_calls_begin_token)
+                if tool_calls_begin_pos > 0:
+                    # 检查中间是否有非空白字符
+                    between_text = after_reasoning[:tool_calls_begin_pos]
+                    if between_text.strip():
+                        # 有非空白字符,协议不规范,不解析工具调用
+                        return ExtractedToolCallInformation(tools_called=False, tool_calls=None, content=model_output)
+
+            tool_calls = []
+
+            if self.is_v31:
+                # V3.1 格式:<|tool▁call▁begin|>function_name<|tool▁sep|>{"arg": "value"}<|tool▁call▁end|>
+                # 转义特殊标记中的 | 字符(在正则表达式中 | 是特殊字符)
+                begin_escaped = self.tool_call_begin_token.replace("|", r"\|")
+                sep_escaped
= self.tool_sep_token.replace("|", r"\|")
+                end_escaped = self.tool_call_end_token.replace("|", r"\|")
+                pattern = f"{begin_escaped}(?P<function_name>[^<]+?){sep_escaped}(?P<function_arguments>.*?){end_escaped}"
+            else:
+                # V3-0324/R1 格式:<|tool▁call▁begin|>function<|tool▁sep|>function_name\n```json\n{"arg": "value"}\n```<|tool▁call▁end|>
+                begin_escaped = self.tool_call_begin_token.replace("|", r"\|")
+                sep_escaped = self.tool_sep_token.replace("|", r"\|")
+                end_escaped = self.tool_call_end_token.replace("|", r"\|")
+                # 注意:代码块标记 ``` 需要转义
+                pattern = f"{begin_escaped}(?P<tool_type>[^<]+?){sep_escaped}(?P<function_name>[^\\n]+?)\\n```json\\n(?P<function_arguments>.*?)\\n```\\n{end_escaped}"
+
+            matches = re.finditer(pattern, model_output, re.DOTALL)
+
+            for match in matches:
+                function_name = match.group("function_name").strip()
+                function_arguments = match.group("function_arguments").strip()
+
+                # 解析参数
+                try:
+                    if function_arguments:
+                        args_dict = json.loads(function_arguments)
+                    else:
+                        args_dict = {}
+                except json.JSONDecodeError:
+                    # 尝试使用 partial_json_parser
+                    try:
+                        args_dict = partial_json_parser.loads(function_arguments, flags=Allow.ALL)
+                    except Exception:
+                        args_dict = {}
+
+                args_str = json.dumps(args_dict, ensure_ascii=False) if args_dict else "{}"
+
+                tool_calls.append(
+                    ToolCall(
+                        type="function",
+                        id=random_tool_call_id(),
+                        function=FunctionCall(
+                            name=function_name,
+                            arguments=args_str,
+                        ),
+                    )
+                )
+
+            if tool_calls:
+                return ExtractedToolCallInformation(tools_called=True, tool_calls=tool_calls, content="")
+            else:
+                return ExtractedToolCallInformation(tools_called=False, tool_calls=None, content=model_output)
+
+        except Exception as e:
+            data_processor_logger.error(f"Error in extracting tool calls from response: {str(e)}")
+            return ExtractedToolCallInformation(tools_called=False, tool_calls=None, content=model_output)
+
+    def extract_tool_calls_streaming(
+        self,
+        previous_text: str,
+        current_text: str,
+        delta_text: str,
+        previous_token_ids: Sequence[int],
+        current_token_ids: Sequence[int],
+        delta_token_ids: Sequence[int],
+
request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + """ + 从增量消息中提取工具调用(流式场景) + """ + try: + # 如果没有工具调用标记,返回 None + if self.tool_calls_begin_token_id not in current_token_ids: + return None + + # 累积到 buffer + self.buffer += delta_text + + # 检测新的工具调用开始 + if self.tool_call_begin_token_id in delta_token_ids: + self.current_tool_id = ( + max(self.current_tool_id, 0) if self.current_tool_id == -1 else self.current_tool_id + 1 + ) + self.current_tool_name_sent = False + if len(self.streamed_args_for_tool) <= self.current_tool_id: + self.streamed_args_for_tool.append("") + data_processor_logger.debug(f"New tool call started with ID: {self.current_tool_id}") + + # 1. 尝试提取工具名称 + if not self.current_tool_name_sent: + # 查找工具名称:在 <|tool▁call▁begin|> 和 <|tool▁sep|> 之间 + begin_pos = self.buffer.find(self.tool_call_begin_token) + if begin_pos != -1: + after_begin = self.buffer[begin_pos + len(self.tool_call_begin_token):] + sep_pos = after_begin.find(self.tool_sep_token) + if sep_pos != -1: + # 提取分隔符之前的内容 + tool_type_or_name = after_begin[:sep_pos].strip() + after_sep = after_begin[sep_pos + len(self.tool_sep_token):] + + # 判断格式:如果是 V3-0324/R1 且提取到的是 "function",则从分隔符后提取函数名 + if not self.is_v31 and tool_type_or_name == "function": + # V3-0324/R1 格式:提取分隔符后、换行符前的内容 + newline_pos = after_sep.find("\n") + if newline_pos != -1: + function_name = after_sep[:newline_pos].strip() + else: + # 如果还没有换行符,暂时返回 None,等待更多数据 + return None + else: + # V3.1 格式:分隔符前的内容就是函数名 + function_name = tool_type_or_name + + if function_name: + # 创建 DeltaMessage + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + type="function", + id=random_tool_call_id(), + function=DeltaFunctionCall(name=function_name).model_dump(exclude_none=True), + ) + ] + ) + # 从 buffer 中移除已处理的部分 + if not self.is_v31 and tool_type_or_name == "function": + # V3-0324/R1:需要移除到换行符之后(包括换行符) + processed_end = begin_pos + len(self.tool_call_begin_token) + sep_pos + len(self.tool_sep_token) + 
newline_pos + 1 + else: + # V3.1:移除到分隔符之后 + processed_end = begin_pos + len(self.tool_call_begin_token) + sep_pos + len(self.tool_sep_token) + self.buffer = self.buffer[processed_end:] + self.current_tool_name_sent = True + return delta + + # 2. 处理参数部分 + if self.current_tool_name_sent: + # 检查是否到达工具调用结束标记 + if self.tool_call_end_token_id in delta_token_ids: + # 工具调用结束,提取完整参数 + end_pos = self.buffer.find(self.tool_call_end_token) + if end_pos != -1: + args_text = self.buffer[:end_pos].strip() + + # 对于 V3-0324/R1,需要从代码块中提取 JSON + if not self.is_v31: + # 移除 ```json 和 ``` 标记 + args_text = re.sub(r"^```json\s*", "", args_text, flags=re.MULTILINE) + args_text = re.sub(r"\s*```\s*$", "", args_text, flags=re.MULTILINE) + args_text = args_text.strip() + + if args_text: + try: + # 尝试解析完整 JSON + args_dict = json.loads(args_text) + args_str = json.dumps(args_dict, ensure_ascii=False) + except json.JSONDecodeError: + # 使用 partial_json_parser + try: + args_dict = partial_json_parser.loads(args_text, flags=Allow.ALL) + args_str = json.dumps(args_dict, ensure_ascii=False) + except: + args_str = args_text + + delta = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall(arguments=args_str).model_dump(exclude_none=True), + ) + ] + ) + # 清理 buffer + self.buffer = self.buffer[end_pos + len(self.tool_call_end_token):] + return delta + else: + # 流式输出参数 + # 对于 V3-0324/R1,需要跳过代码块标记 + args_text = self.buffer + if not self.is_v31: + # 移除开头的 ```json 标记(如果存在) + args_text = re.sub(r"^```json\s*", "", args_text, flags=re.MULTILINE) + + if args_text.strip(): + # 尝试解析部分 JSON + try: + # 使用 partial_json_parser 解析部分 JSON + args_dict = partial_json_parser.loads(args_text, flags=Allow.ALL) + args_str = json.dumps(args_dict, ensure_ascii=False) + except: + # 如果解析失败,直接使用原始文本 + args_str = args_text + + # 计算增量部分(只返回新增的部分) + if len(self.streamed_args_for_tool) > self.current_tool_id: + prev_args = self.streamed_args_for_tool[self.current_tool_id] + if 
args_str.startswith(prev_args): + new_args = args_str[len(prev_args):] + if new_args: + self.streamed_args_for_tool[self.current_tool_id] = args_str + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall(arguments=new_args).model_dump(exclude_none=True), + ) + ] + ) + else: + # 第一次收到参数 + self.streamed_args_for_tool.append(args_str) + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall(arguments=args_str).model_dump(exclude_none=True), + ) + ] + ) + + return None + + except Exception as e: + data_processor_logger.error(f"Error in streaming tool call extraction: {str(e)}") + return None diff --git a/fastdeploy/reasoning/deepseek_reasoning_parser.py b/fastdeploy/reasoning/deepseek_reasoning_parser.py new file mode 100644 index 00000000000..64da44391ba --- /dev/null +++ b/fastdeploy/reasoning/deepseek_reasoning_parser.py @@ -0,0 +1,192 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""
+
+from collections.abc import Sequence
+from typing import Optional, Union
+
+from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
+from fastdeploy.reasoning import ReasoningParser, ReasoningParserManager
+
+
+@ReasoningParserManager.register_module(["deepseek", "deepseek-r1", "deepseek-v3.1", "deepseek-v3-0324"])
+class DeepSeekReasoningParser(ReasoningParser):
+    """
+    Reasoning parser for DeepSeek models (V3.1, V3-0324, R1).
+    Extracts reasoning content and response content from model output.
+    """
+
+    def __init__(self, tokenizer, model_name=None):
+        super().__init__(tokenizer)
+        self.think_start_token = "<think>"
+        self.think_end_token = "</think>"
+
+        if not self.model_tokenizer:
+            raise ValueError(
+                "The model tokenizer must be passed to the ReasoningParser constructor during construction."
+            )
+
+        # Get special token IDs
+        self.think_start_token_id = self.vocab.get(self.think_start_token)
+        self.think_end_token_id = self.vocab.get(self.think_end_token)
+
+        if self.think_end_token_id is None:
+            raise RuntimeError(
+                "DeepSeek reasoning parser could not locate think end tokens in the tokenizer!"
+            )
+
+        # Detect model version to determine if reasoning toggle is supported
+        self.model_name = model_name or ""
+        self.supports_reasoning_toggle = "v3.1" in self.model_name.lower()
+
+    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
+        """Check if reasoning content has ended (check for the </think> token)."""
+        return self.think_end_token_id in input_ids
+
+    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
+        """Extract content token IDs after </think>."""
+        if self.think_end_token_id not in input_ids:
+            return input_ids
+
+        # Find position of </think>
+        end_index = input_ids.index(self.think_end_token_id)
+        # Return all token IDs after the end token
+        return input_ids[end_index + 1 :]
+
+    def detect_output_stage(self, prompt_token_ids: Sequence[int]) -> str:
+        """Detect output stage based on prompt token IDs."""
+        # Check if prompt contains the <think> start token
+        if self.think_start_token_id is not None and self.think_start_token_id in prompt_token_ids:
+            # Check if thinking stage has ended
+            if self.think_end_token_id is not None and self.think_end_token_id in prompt_token_ids:
+                # Thinking ended, enter content stage
+                return "CONTENT_STAGE"
+            else:
+                # Still in thinking stage
+                return "REASONING_STAGE"
+        else:
+            # No thinking tokens, possibly reasoning toggle is off
+            # Default to content stage
+            return "CONTENT_STAGE"
+
+    def extract_reasoning_content(
+        self, model_output: str, request: ChatCompletionRequest, output_stage: Optional[str] = None
+    ) -> tuple[Optional[str], Optional[str]]:
+        """
+        Extract reasoning content and response content from complete model output (non-streaming).
+        Supports formats: <think>abc</think>xyz, abc</think>xyz, or xyz.
+        """
+        # Check for start token
+        if self.think_start_token in model_output:
+            # Standard format: <think>content</think>answer
+            # Remove start token
+            model_output_parts = model_output.partition(self.think_start_token)
+            model_output = model_output_parts[2] if model_output_parts[1] else model_output_parts[0]
+
+            # Check for end token
+            if self.think_end_token not in model_output:
+                # Only start token, no end token: treat entire content as reasoning
+                return model_output, None
+
+            # Extract reasoning and response content
+            reasoning_content, _, content = model_output.partition(self.think_end_token)
+
+            # Treat whitespace-only content as empty
+            final_content = content.strip() if content.strip() else None
+            return reasoning_content, final_content
+
+        # Check for end token (but no start token)
+        if self.think_end_token in model_output:
+            # Missing start token format: content</think>answer
+            parts = model_output.split(self.think_end_token, 1)
+
+            if len(parts) == 2:
+                reasoning_content = parts[0].strip()
+                final_content = parts[1].strip() if parts[1].strip() else None
+                return reasoning_content, final_content
+
+        # No thinking tokens at all
+        if output_stage == "REASONING_STAGE":
+            # If detected as reasoning stage but no end token, treat as protocol error
+            # Return entire output as reasoning_content
+            return model_output, None
+        else:
+            # Reasoning toggle off or in content stage: return entire output as content
+            return None, model_output
+
+    def extract_reasoning_content_streaming(
+        self,
+        previous_text: str,
+        current_text: str,
+        delta_text: str,
+        previous_token_ids: Sequence[int],
+        current_token_ids: Sequence[int],
+        delta_token_ids: Sequence[int],
+        output_stage: Optional[str] = None,
+    ) -> Union[DeltaMessage, None]:
+        """
+        Extract reasoning content from incremental messages (streaming).
+        Handles streaming output where previous + delta = current.
+        Uses token IDs for faster processing.
+        """
+        # Ignore a lone </think> token
+        if len(delta_token_ids) == 1 and delta_token_ids[0] == self.think_end_token_id:
+            return None
+
+        # If delta contains </think>
+        if self.think_end_token_id in delta_token_ids:
+            # If delta contains both <think> and </think>
+            if self.think_start_token_id and self.think_start_token_id in delta_token_ids:
+                start_index = delta_text.find(self.think_start_token)
+                end_index = delta_text.find(self.think_end_token)
+                if start_index != -1 and end_index != -1:
+                    reasoning_content = delta_text[start_index + len(self.think_start_token) : end_index]
+                    content = delta_text[end_index + len(self.think_end_token) :]
+                    return DeltaMessage(reasoning_content=reasoning_content, content=content if content else None)
+            # If </think> in delta but <think> in previous
+            else:
+                end_index = delta_text.find(self.think_end_token)
+                if end_index != -1:
+                    reasoning_content = delta_text[:end_index]
+                    content = delta_text[end_index + len(self.think_end_token) :]
+                    # Treat whitespace-only content as empty
+                    content = content if content.strip() else None
+                    return DeltaMessage(reasoning_content=reasoning_content, content=content)
+
+        # If </think> in previous, already in content stage
+        if self.think_end_token_id in previous_token_ids:
+            return DeltaMessage(content=delta_text)
+
+        # If <think> in previous, still in thinking stage
+        if self.think_start_token_id and self.think_start_token_id in previous_token_ids:
+            return DeltaMessage(reasoning_content=delta_text)
+
+        # If <think> in delta
+        if self.think_start_token_id and self.think_start_token_id in delta_token_ids:
+            start_index = delta_text.find(self.think_start_token)
+            if start_index != -1:
+                reasoning_content = delta_text[start_index + len(self.think_start_token) :]
+                return DeltaMessage(reasoning_content=reasoning_content, content=None)
+
+        # Default: determine based on output_stage
+        # If no <think> tokens seen, possibly the reasoning toggle is off
+        if output_stage == "CONTENT_STAGE":
+            # In content stage, return delta as content
+            return DeltaMessage(content=delta_text)
+        else:
+            # In
thinking stage or unknown, return delta as reasoning_content + # Will be handled correctly if appears later + return DeltaMessage(reasoning_content=delta_text) + diff --git a/tests/entrypoints/openai/tool_parsers/test_deepseek_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_deepseek_tool_parser.py new file mode 100644 index 00000000000..e4cb9da5299 --- /dev/null +++ b/tests/entrypoints/openai/tool_parsers/test_deepseek_tool_parser.py @@ -0,0 +1,584 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""
+
+import json
+import unittest
+
+from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
+from fastdeploy.entrypoints.openai.tool_parsers.deepseek_tool_parser import DeepSeekToolParser
+
+
+class DummyTokenizer:
+    """Minimal tokenizer with vocab for testing."""
+
+    def __init__(self):
+        self.vocab = {
+            "<think>": 128798,
+            "</think>": 128799,
+            "<|tool▁calls▁begin|>": 128806,
+            "<|tool▁calls▁end|>": 128807,
+            "<|tool▁call▁begin|>": 128808,
+            "<|tool▁call▁end|>": 128809,
+            "<|tool▁sep|>": 128814,
+        }
+
+    def get_vocab(self):
+        """Return vocab dict for testing."""
+        return self.vocab
+
+
+class TestDeepSeekToolParserV31(unittest.TestCase):
+    """Test tool parser for DeepSeek-V3.1 format."""
+
+    def setUp(self):
+        self.tokenizer = DummyTokenizer()
+        self.parser = DeepSeekToolParser(tokenizer=self.tokenizer, model_name="deepseek-v3.1")
+        self.request = ChatCompletionRequest(
+            model="deepseek-v3.1",
+            messages=[{"role": "user", "content": "What's the weather in Beijing?"}],
+            tools=[
+                {
+                    "type": "function",
+                    "function": {
+                        "name": "get_weather",
+                        "parameters": {
+                            "type": "object",
+                            "properties": {
+                                "location": {"type": "string"},
+                                "unit": {"type": "string", "enum": ["c", "f"]},
+                            },
+                        },
+                    },
+                }
+            ],
+        )
+
+    # ---- Non-streaming parsing ----
+    def test_batch_single_tool_call(self):
+        """Test single tool call (V3.1 format)."""
+        text = '<think>需要查询天气</think>\n\n<|tool▁calls▁begin|><|tool▁call▁begin|>get_weather<|tool▁sep|>{"location": "北京", "unit": "c"}<|tool▁call▁end|><|tool▁calls▁end|>'
+        result = self.parser.extract_tool_calls(text, self.request)
+        self.assertTrue(result.tools_called)
+        self.assertIsNotNone(result.tool_calls)
+        self.assertEqual(len(result.tool_calls), 1)
+        self.assertEqual(result.tool_calls[0].function.name, "get_weather")
+        self.assertIn("location", result.tool_calls[0].function.arguments)
+        self.assertEqual(result.content, "")
+
+    def test_batch_parallel_tool_calls(self):
+        """Test parallel tool calls (V3.1 format)."""
+        text = (
+            '<think>需要查询多个信息</think>\n\n'
+            '<|tool▁calls▁begin|>'
+            '<|tool▁call▁begin|>get_weather<|tool▁sep|>{"location": "北京", "unit": "c"}<|tool▁call▁end|>'
+            '<|tool▁call▁begin|>get_time<|tool▁sep|>{"timezone": "Asia/Shanghai"}<|tool▁call▁end|>'
+            '<|tool▁calls▁end|>'
+        )
+        result = self.parser.extract_tool_calls(text, self.request)
+        self.assertTrue(result.tools_called)
+        self.assertIsNotNone(result.tool_calls)
+        self.assertEqual(len(result.tool_calls), 2)
+        self.assertEqual(result.tool_calls[0].function.name, "get_weather")
+        self.assertEqual(result.tool_calls[1].function.name, "get_time")
+
+    def test_batch_no_tool_calls(self):
+        """Test no tool calls."""
+        text = "<think>这是普通回复</think>\n\n这是回复内容"
+        result = self.parser.extract_tool_calls(text, self.request)
+        self.assertFalse(result.tools_called)
+        self.assertIsNone(result.tool_calls)
+        self.assertEqual(result.content, text)
+
+    def test_batch_invalid_format_with_content_before_tool(self):
+        """Test invalid format: non-whitespace content between reasoning end and tool calls."""
+        text = '<think>思考内容</think>\n\nABC\n<|tool▁calls▁begin|><|tool▁call▁begin|>get_weather<|tool▁sep|>{"location": "北京"}<|tool▁call▁end|><|tool▁calls▁end|>'
+        result = self.parser.extract_tool_calls(text, self.request)
+        self.assertFalse(result.tools_called)
+        self.assertEqual(result.content, text)
+
+    def test_batch_partial_json(self):
+        """Test incomplete JSON arguments."""
+        text = '<think>需要查询天气</think>\n\n<|tool▁calls▁begin|><|tool▁call▁begin|>get_weather<|tool▁sep|>{"location": "北京", "unit": "c"}<|tool▁call▁end|><|tool▁calls▁end|>'
+        result = self.parser.extract_tool_calls(text, self.request)
+        self.assertTrue(result.tools_called)
+        self.assertIsNotNone(result.tool_calls)
+
+    # ---- Streaming parsing ----
+    def test_streaming_tool_name(self):
+        """Test streaming tool name extraction."""
+        # Reset parser state
+        self.parser.buffer = ""
+        self.parser.current_tool_name_sent = False
+        self.parser.streamed_args_for_tool = []
+        self.parser.current_tool_id = -1
+
+        # Step 1: Receive reasoning content
and tool call begin token
+        previous_text = "需要查询天气</think>\n\n"
+        current_text = previous_text + "<|tool▁calls▁begin|>"
+        delta_text = "<|tool▁calls▁begin|>"
+        msg = self.parser.extract_tool_calls_streaming(
+            previous_text=previous_text,
+            current_text=current_text,
+            delta_text=delta_text,
+            previous_token_ids=[],
+            current_token_ids=[128806],
+            delta_token_ids=[128806],
+            request=self.request,
+        )
+        # No tool name yet, should return None
+        self.assertIsNone(msg)
+
+        # Step 2: Receive tool call begin token and partial tool name
+        previous_text = current_text
+        current_text = previous_text + "<|tool▁call▁begin|>get"
+        delta_text = "<|tool▁call▁begin|>get"
+        msg = self.parser.extract_tool_calls_streaming(
+            previous_text=previous_text,
+            current_text=current_text,
+            delta_text=delta_text,
+            previous_token_ids=[128806],
+            current_token_ids=[128806, 128808, 200],
+            delta_token_ids=[128808, 200],
+            request=self.request,
+        )
+        # No separator yet, should return None
+        self.assertIsNone(msg)
+
+        # Step 3: Receive complete tool name and separator
+        previous_text = current_text
+        current_text = previous_text + "_weather<|tool▁sep|>"
+        delta_text = "_weather<|tool▁sep|>"
+        msg = self.parser.extract_tool_calls_streaming(
+            previous_text=previous_text,
+            current_text=current_text,
+            delta_text=delta_text,
+            previous_token_ids=[128806, 128808, 200],
+            current_token_ids=[128806, 128808, 200, 201, 128814],
+            delta_token_ids=[201, 128814],
+            request=self.request,
+        )
+        # Should extract tool name now
+        self.assertIsNotNone(msg)
+        self.assertIsNotNone(msg.tool_calls)
+        self.assertEqual(len(msg.tool_calls), 1)
+        self.assertEqual(msg.tool_calls[0].function.name, "get_weather")
+        self.assertIsNone(msg.tool_calls[0].function.arguments)
+
+    def test_streaming_tool_arguments(self):
+        """Test streaming tool arguments extraction."""
+        # Reset parser state
+        self.parser.buffer = ""
+        self.parser.current_tool_name_sent = False
+        self.parser.streamed_args_for_tool = []
+        self.parser.current_tool_id = -1
+
+        # Step 1: Set tool name sent state (simulate tool name already extracted)
+        previous_text = "需要查询天气</think>\n\n"
+        current_text = previous_text + "<|tool▁calls▁begin|><|tool▁call▁begin|>get_weather<|tool▁sep|>"
+        delta_text = "<|tool▁calls▁begin|><|tool▁call▁begin|>get_weather<|tool▁sep|>"
+        msg = self.parser.extract_tool_calls_streaming(
+            previous_text=previous_text,
+            current_text=current_text,
+            delta_text=delta_text,
+            previous_token_ids=[],
+            current_token_ids=[128806, 128808, 200, 201, 128814],
+            delta_token_ids=[128806, 128808, 200, 201, 128814],
+            request=self.request,
+        )
+        # Should extract tool name now
+        self.assertIsNotNone(msg)
+        self.assertEqual(msg.tool_calls[0].function.name, "get_weather")
+        self.assertTrue(self.parser.current_tool_name_sent)
+
+        # Step 2: Receive partial JSON arguments
+        previous_text = current_text
+        current_text = previous_text + '{"location": "'
+        delta_text = '{"location": "'
+        msg = self.parser.extract_tool_calls_streaming(
+            previous_text=previous_text,
+            current_text=current_text,
+            delta_text=delta_text,
+            previous_token_ids=[128806, 128808, 200, 201, 128814],
+            current_token_ids=[128806, 128808, 200, 201, 128814, 300, 301],
+            delta_token_ids=[300, 301],
+            request=self.request,
+        )
+        # Should return incremental arguments
+        self.assertIsNotNone(msg)
+        self.assertIsNotNone(msg.tool_calls)
+        self.assertEqual(len(msg.tool_calls), 1)
+        self.assertIn("location", msg.tool_calls[0].function.arguments)
+
+        # Step 3: Receive more arguments
+        previous_text = current_text
+        current_text = previous_text + '北京", "unit": "c"}'
+        delta_text = '北京", "unit": "c"}'
+        msg = self.parser.extract_tool_calls_streaming(
+            previous_text=previous_text,
+            current_text=current_text,
+            delta_text=delta_text,
+            previous_token_ids=[128806, 128808, 200, 201, 128814, 300, 301],
+            current_token_ids=[128806, 128808, 200, 201, 128814, 300, 301, 302, 303],
+            delta_token_ids=[302, 303],
+            request=self.request,
+        )
+        # Should return incremental arguments
+        self.assertIsNotNone(msg)
+        self.assertIsNotNone(msg.tool_calls)
+        self.assertIn("unit", msg.tool_calls[0].function.arguments)
+
+        # Step 4: Receive tool call end token
+        previous_text = current_text
+        current_text = previous_text + "<|tool▁call▁end|>"
+        delta_text = "<|tool▁call▁end|>"
+        msg = self.parser.extract_tool_calls_streaming(
+            previous_text=previous_text,
+            current_text=current_text,
+            delta_text=delta_text,
+            previous_token_ids=[128806, 128808, 200, 201, 128814, 300, 301, 302, 303],
+            current_token_ids=[128806, 128808, 200, 201, 128814, 300, 301, 302, 303, 128809],
+            delta_token_ids=[128809],
+            request=self.request,
+        )
+        # Should return complete arguments
+        self.assertIsNotNone(msg)
+        self.assertIsNotNone(msg.tool_calls)
+        arguments = json.loads(msg.tool_calls[0].function.arguments)
+        self.assertEqual(arguments["location"], "北京")
+        self.assertEqual(arguments["unit"], "c")
+
+
+class TestDeepSeekToolParserV30324(unittest.TestCase):
+    """Test tool parser for DeepSeek-V3-0324/R1 format."""
+
+    def setUp(self):
+        self.tokenizer = DummyTokenizer()
+        self.parser = DeepSeekToolParser(tokenizer=self.tokenizer, model_name="deepseek-v3-0324")
+        self.request = ChatCompletionRequest(
+            model="deepseek-v3-0324",
+            messages=[{"role": "user", "content": "What's the weather in Beijing?"}],
+            tools=[
+                {
+                    "type": "function",
+                    "function": {
+                        "name": "get_weather",
+                        "parameters": {
+                            "type": "object",
+                            "properties": {
+                                "location": {"type": "string"},
+                                "unit": {"type": "string", "enum": ["c", "f"]},
+                            },
+                        },
+                    },
+                }
+            ],
+        )
+
+    # ---- Non-streaming parsing ----
+    def test_batch_single_tool_call(self):
+        """Test single tool call (V3-0324/R1 format)."""
+        text = (
+            '需要查询天气</think>\n\n'
+            '<|tool▁calls▁begin|>'
+            '<|tool▁call▁begin|>function<|tool▁sep|>get_weather\n'
+            '```json\n'
+            '{"location": "北京", "unit": "c"}\n'
+            '```\n'
+            '<|tool▁call▁end|>'
+            '<|tool▁calls▁end|>'
+        )
+        result = self.parser.extract_tool_calls(text, self.request)
+        self.assertTrue(result.tools_called)
+        self.assertIsNotNone(result.tool_calls)
+        self.assertEqual(len(result.tool_calls), 1)
+        self.assertEqual(result.tool_calls[0].function.name, "get_weather")
+        self.assertIn("location", result.tool_calls[0].function.arguments)
+
+    def test_batch_parallel_tool_calls(self):
+        """Test parallel tool calls (V3-0324/R1 format)."""
+        text = (
+            '需要查询多个信息</think>\n\n'
+            '<|tool▁calls▁begin|>'
+            '<|tool▁call▁begin|>function<|tool▁sep|>get_weather\n'
+            '```json\n'
+            '{"location": "北京", "unit": "c"}\n'
+            '```\n'
+            '<|tool▁call▁end|>'
+            '<|tool▁call▁begin|>function<|tool▁sep|>get_time\n'
+            '```json\n'
+            '{"timezone": "Asia/Shanghai"}\n'
+            '```\n'
+            '<|tool▁call▁end|>'
+            '<|tool▁calls▁end|>'
+        )
+        result = self.parser.extract_tool_calls(text, self.request)
+        self.assertTrue(result.tools_called)
+        self.assertIsNotNone(result.tool_calls)
+        self.assertEqual(len(result.tool_calls), 2)
+        self.assertEqual(result.tool_calls[0].function.name, "get_weather")
+        self.assertEqual(result.tool_calls[1].function.name, "get_time")
+
+    def test_batch_no_tool_calls(self):
+        """Test no tool calls."""
+        text = "这是普通回复</think>\n\n这是回复内容"
+        result = self.parser.extract_tool_calls(text, self.request)
+        self.assertFalse(result.tools_called)
+        self.assertIsNone(result.tool_calls)
+        self.assertEqual(result.content, text)
+
+    def test_batch_invalid_format_with_content_before_tool(self):
+        """Test invalid format: non-whitespace content between reasoning end and tool calls."""
+        text = (
+            '思考内容</think>\n\nABC\n'
+            '<|tool▁calls▁begin|>'
+            '<|tool▁call▁begin|>function<|tool▁sep|>get_weather\n'
+            '```json\n'
+            '{"location": "北京"}\n'
+            '```\n'
+            '<|tool▁call▁end|>'
+            '<|tool▁calls▁end|>'
+        )
+        result = self.parser.extract_tool_calls(text, self.request)
+        self.assertFalse(result.tools_called)
+        self.assertEqual(result.content, text)
+
+    # ---- Streaming parsing ----
+    def test_streaming_tool_name_v30324(self):
+        """Test streaming tool name extraction (V3-0324/R1 format)."""
+        # Reset parser state
+        self.parser.buffer = ""
+        self.parser.current_tool_name_sent = False
+        self.parser.streamed_args_for_tool = []
+        self.parser.current_tool_id = -1
+
+        # Step 1: Receive reasoning content and tool call begin token
+        previous_text = "需要查询天气</think>\n\n"
+        current_text = previous_text + "<|tool▁calls▁begin|>"
+        delta_text = "<|tool▁calls▁begin|>"
+        msg = self.parser.extract_tool_calls_streaming(
+            previous_text=previous_text,
+            current_text=current_text,
+            delta_text=delta_text,
+            previous_token_ids=[],
+            current_token_ids=[128806],
+            delta_token_ids=[128806],
+            request=self.request,
+        )
+        # No tool name yet, should return None
+        self.assertIsNone(msg)
+
+        # Step 2: Receive tool call begin token and "function"
+        previous_text = current_text
+        current_text = previous_text + "<|tool▁call▁begin|>function"
+        delta_text = "<|tool▁call▁begin|>function"
+        msg = self.parser.extract_tool_calls_streaming(
+            previous_text=previous_text,
+            current_text=current_text,
+            delta_text=delta_text,
+            previous_token_ids=[128806],
+            current_token_ids=[128806, 128808, 300],
+            delta_token_ids=[128808, 300],
+            request=self.request,
+        )
+        # No separator yet, should return None
+        self.assertIsNone(msg)
+
+        # Step 3: Receive separator
+        previous_text = current_text
+        current_text = previous_text + "<|tool▁sep|>"
+        delta_text = "<|tool▁sep|>"
+        msg = self.parser.extract_tool_calls_streaming(
+            previous_text=previous_text,
+            current_text=current_text,
+            delta_text=delta_text,
+            previous_token_ids=[128806, 128808, 300],
+            current_token_ids=[128806, 128808, 300, 128814],
+            delta_token_ids=[128814],
+            request=self.request,
+        )
+        # Detected "function" but no newline yet, should return None and wait for more data
+        self.assertIsNone(msg)
+
+        # Step 4: Receive function name and newline
+        previous_text = current_text
+        current_text = previous_text + "get_weather\n"
+        delta_text = "get_weather\n"
+        msg = self.parser.extract_tool_calls_streaming(
+            previous_text=previous_text,
+            current_text=current_text,
+            delta_text=delta_text,
+            previous_token_ids=[128806, 128808, 300, 128814],
+            current_token_ids=[128806, 128808, 300, 128814, 200, 201, 202, 10],
+            delta_token_ids=[200, 201, 202, 10],
+            request=self.request,
+        )
+        # Should extract tool name now
+        self.assertIsNotNone(msg)
+        self.assertIsNotNone(msg.tool_calls)
+        self.assertEqual(len(msg.tool_calls), 1)
+        self.assertEqual(msg.tool_calls[0].function.name, "get_weather")
+        self.assertIsNone(msg.tool_calls[0].function.arguments)
+
+    def test_streaming_tool_arguments_v30324(self):
+        """Test streaming tool arguments extraction (V3-0324/R1 format)."""
+        # Reset parser state
+        self.parser.buffer = ""
+        self.parser.current_tool_name_sent = False
+        self.parser.streamed_args_for_tool = []
+        self.parser.current_tool_id = -1
+
+        # Step 1: Set tool name sent state (simulate tool name already extracted)
+        previous_text = "需要查询天气</think>\n\n"
+        current_text = previous_text + "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_weather\n"
+        delta_text = "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_weather\n"
+        msg = self.parser.extract_tool_calls_streaming(
+            previous_text=previous_text,
+            current_text=current_text,
+            delta_text=delta_text,
+            previous_token_ids=[],
+            current_token_ids=[128806, 128808, 300, 128814, 200, 201, 202, 10],
+            delta_token_ids=[128806, 128808, 300, 128814, 200, 201, 202, 10],
+            request=self.request,
+        )
+        # Should extract tool name now
+        self.assertIsNotNone(msg)
+        self.assertEqual(msg.tool_calls[0].function.name, "get_weather")
+        self.assertTrue(self.parser.current_tool_name_sent)
+
+        # Step 2: Receive code block start token
+        previous_text = current_text
+        current_text = previous_text + "```json\n"
+        delta_text = "```json\n"
+        msg = self.parser.extract_tool_calls_streaming(
+            previous_text=previous_text,
+            current_text=current_text,
+            delta_text=delta_text,
+            previous_token_ids=[128806, 128808, 300, 128814, 200, 201, 202, 10],
+            current_token_ids=[128806, 128808, 300, 128814, 200, 201, 202, 10, 400, 401],
+            delta_token_ids=[400, 401],
+            request=self.request,
+        )
+        # Code block token should be skipped, return None (no argument content yet)
+        self.assertIsNone(msg)
+
+        # Step 3: Receive partial JSON arguments
+        previous_text = current_text
+        current_text = previous_text + '{"location": "'
+        delta_text = '{"location": "'
+        msg = self.parser.extract_tool_calls_streaming(
+            previous_text=previous_text,
+            current_text=current_text,
+            delta_text=delta_text,
+            previous_token_ids=[128806, 128808, 300, 128814, 200, 201, 202, 10, 400, 401],
+            current_token_ids=[128806, 128808, 300, 128814, 200, 201, 202, 10, 400, 401, 500, 501],
+            delta_token_ids=[500, 501],
+            request=self.request,
+        )
+        # Should return incremental arguments
+        self.assertIsNotNone(msg)
+        self.assertIsNotNone(msg.tool_calls)
+        self.assertEqual(len(msg.tool_calls), 1)
+        self.assertIn("location", msg.tool_calls[0].function.arguments)
+
+        # Step 4: Receive more arguments
+        previous_text = current_text
+        current_text = previous_text + '北京", "unit": "c"}\n'
+        delta_text = '北京", "unit": "c"}\n'
+        msg = self.parser.extract_tool_calls_streaming(
+            previous_text=previous_text,
+            current_text=current_text,
+            delta_text=delta_text,
+            previous_token_ids=[128806, 128808, 300, 128814, 200, 201, 202, 10, 400, 401, 500, 501],
+            current_token_ids=[128806, 128808, 300, 128814, 200, 201, 202, 10, 400, 401, 500, 501, 502, 503, 10],
+            delta_token_ids=[502, 503, 10],
+            request=self.request,
+        )
+        # Should return incremental arguments
+        self.assertIsNotNone(msg)
+        self.assertIsNotNone(msg.tool_calls)
+        self.assertIn("unit", msg.tool_calls[0].function.arguments)
+
+        # Step 5: Receive code block end token and tool call end token
+        previous_text = current_text
+        current_text = previous_text + "```\n<|tool▁call▁end|>"
+        delta_text = "```\n<|tool▁call▁end|>"
+        msg = self.parser.extract_tool_calls_streaming(
+            previous_text=previous_text,
+            current_text=current_text,
+            delta_text=delta_text,
+            previous_token_ids=[128806, 128808, 300, 128814, 200, 201, 202, 10, 400, 401, 500, 501, 502, 503, 10],
+            current_token_ids=[128806, 128808, 300, 128814, 200, 201, 202, 10, 400, 401, 500, 501, 502, 503, 10, 402, 10, 128809],
+            delta_token_ids=[402, 10, 128809],
+            request=self.request,
+        )
+        # Should return complete arguments
+        self.assertIsNotNone(msg)
+        self.assertIsNotNone(msg.tool_calls)
+        arguments = json.loads(msg.tool_calls[0].function.arguments)
+        self.assertEqual(arguments["location"], "北京")
+        self.assertEqual(arguments["unit"], "c")
+
+
+class TestDeepSeekToolParserEdgeCases(unittest.TestCase):
+    """Test edge cases."""
+
+    def setUp(self):
+        self.tokenizer = DummyTokenizer()
+        self.parser = DeepSeekToolParser(tokenizer=self.tokenizer, model_name="deepseek-v3.1")
+        self.request = ChatCompletionRequest(
+            model="deepseek-v3.1",
+            messages=[{"role": "user", "content": "test"}],
+            tools=[
+                {
+                    "type": "function",
+                    "function": {
+                        "name": "test_function",
+                        "parameters": {"type": "object", "properties": {}},
+                    },
+                }
+            ],
+        )
+
+    def test_detect_output_stage_tool_call(self):
+        """Test detecting tool call stage."""
+        prompt_token_ids = [128806]  # Contains <|tool▁calls▁begin|>
+        stage = self.parser.detect_output_stage(prompt_token_ids)
+        self.assertEqual(stage, "TOOL_CALL_STAGE")
+
+    def test_detect_output_stage_content(self):
+        """Test detecting content stage."""
+        prompt_token_ids = [200, 201, 202]  # No tool call tokens
+        stage = self.parser.detect_output_stage(prompt_token_ids)
+        self.assertEqual(stage, "CONTENT_STAGE")
+
+    def test_empty_tool_calls_block(self):
+        """Test empty tool calls block."""
+        text = "思考内容</think>\n\n<|tool▁calls▁begin|><|tool▁calls▁end|>"
+        result = self.parser.extract_tool_calls(text, self.request)
+        self.assertFalse(result.tools_called)
+
+    def test_malformed_json(self):
+        """Test malformed JSON."""
+        text = '思考内容</think>\n\n<|tool▁calls▁begin|><|tool▁call▁begin|>get_weather<|tool▁sep|>{"location": "北京", "unit":}<|tool▁call▁end|><|tool▁calls▁end|>'
+        result = self.parser.extract_tool_calls(text, self.request)
+        # Should handle gracefully, at least extract function name
+        if result.tools_called:
+            self.assertEqual(result.tool_calls[0].function.name, "get_weather")
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/reasoning/test_reasoning_parser.py b/tests/reasoning/test_reasoning_parser.py
index e6deded445d..5ba4ae9b1a9 100644
--- a/tests/reasoning/test_reasoning_parser.py
+++ b/tests/reasoning/test_reasoning_parser.py
@@ -18,6 +18,7 @@
 from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
 from fastdeploy.reasoning import ReasoningParser, ReasoningParserManager
+from fastdeploy.reasoning.deepseek_reasoning_parser import DeepSeekReasoningParser
 from fastdeploy.reasoning.ernie_45_vl_thinking_reasoning_parser import (
     Ernie45VLThinkingReasoningParser,
 )
@@ -30,6 +31,7 @@ class DummyTokenizer:
     def __init__(self):
         self.vocab = {
+            "<think>": 99,
             "": 100,
             "": 101,
             "": 102,
@@ -481,5 +483,216 @@ def test_extract_reasoning_content(self):
         self.assertEqual(content, "\nactual response")
 
 
+class TestDeepSeekReasoningParser(unittest.TestCase):
+    def setUp(self):
+        self.tokenizer = DummyTokenizer()
+        self.parser = DeepSeekReasoningParser(tokenizer=self.tokenizer, model_name="deepseek-v3.1")
+        self.request = ChatCompletionRequest(
+            model="deepseek-v3.1", messages=[{"role": "user", "content": "test message"}]
+        )
+
+    # ---- Non-streaming parsing ----
+    def test_batch_standard_format(self):
+        """Test the standard format: <think>abc</think>xyz"""
+        text = "<think>abc</think>xyz"
+        reasoning, content = self.parser.extract_reasoning_content(text, self.request)
+        self.assertEqual(reasoning, "abc")
+        self.assertEqual(content, "xyz")
+
+    def test_batch_no_start_tag(self):
+        """Test the format missing the start tag: abc</think>xyz"""
+        text = "abc</think>xyz"
+        reasoning, content = self.parser.extract_reasoning_content(text, self.request)
+        self.assertEqual(reasoning, "abc")
+        self.assertEqual(content, "xyz")
+
+    def test_batch_no_reasoning_tags(self):
+        """Test the format without reasoning tags (thinking switch off)."""
+        text = "direct response"
+        reasoning, content = self.parser.extract_reasoning_content(text, self.request)
+        self.assertIsNone(reasoning)
+        self.assertEqual(content, "direct response")
+
+    def test_batch_only_start_tag(self):
+        """Test a start tag with no end tag."""
+        text = "<think>abc"
+        reasoning, content = self.parser.extract_reasoning_content(text, self.request)
+        self.assertEqual(reasoning, "abc")
+        self.assertIsNone(content)
+
+    def test_batch_reasoning_with_newline(self):
+        """Test reasoning content that contains newlines."""
+        text = "<think>line1\nline2</think>response"
+        reasoning, content = self.parser.extract_reasoning_content(text, self.request)
+        self.assertEqual(reasoning, "line1\nline2")
+        self.assertEqual(content, "response")
+
+    def test_batch_empty_content(self):
+        """Test no reply content after reasoning ends."""
+        text = "<think>abc</think>"
+        reasoning, content = self.parser.extract_reasoning_content(text, self.request)
+        self.assertEqual(reasoning, "abc")
+        self.assertIsNone(content)
+
+    def test_batch_content_with_whitespace(self):
+        """Test only whitespace after reasoning ends."""
+        text = "<think>abc</think>\n\n "
+        reasoning, content = self.parser.extract_reasoning_content(text, self.request)
+        self.assertEqual(reasoning, "abc")
+        self.assertIsNone(content)
+
+    # ---- Streaming parsing ----
+    def test_streaming_reasoning_content(self):
+        """Test streaming of reasoning content."""
+        msg = self.parser.extract_reasoning_content_streaming(
+            previous_text="",
+            current_text="a",
+            delta_text="a",
+            previous_token_ids=[],
+            current_token_ids=[200],
+            delta_token_ids=[200],
+        )
+        self.assertIsNotNone(msg)
+        self.assertEqual(msg.reasoning_content, "a")
+        self.assertIsNone(msg.content)
+
+    def test_streaming_reasoning_end_tag(self):
+        """Test streaming when the end tag arrives."""
+        msg = self.parser.extract_reasoning_content_streaming(
+            previous_text="abc",
+            current_text="abc</think>",
+            delta_text="</think>",
+            previous_token_ids=[200, 201, 202],
+            current_token_ids=[200, 201, 202, self.parser.think_end_token_id],
+            delta_token_ids=[self.parser.think_end_token_id],
+        )
+        self.assertIsNone(msg)  # A bare end tag should be ignored
+
+    def test_streaming_reasoning_to_content(self):
+        """Test the transition from the reasoning stage to the reply stage."""
+        msg = self.parser.extract_reasoning_content_streaming(
+            previous_text="abc</think>",
+            current_text="abc</think>xyz",
+            delta_text="xyz",
+            previous_token_ids=[200, 201, 202, self.parser.think_end_token_id],
+            current_token_ids=[200, 201, 202, self.parser.think_end_token_id, 110, 120, 130],
+            delta_token_ids=[110, 120, 130],
+        )
+        self.assertIsNotNone(msg)
+        self.assertIsNone(msg.reasoning_content)
+        self.assertEqual(msg.content, "xyz")
+
+    def test_streaming_reasoning_and_content_in_delta(self):
+        """Test a delta that contains both reasoning and reply content."""
+        msg = self.parser.extract_reasoning_content_streaming(
+            previous_text="",
+            current_text="<think>abc</think>xyz",
+            delta_text="<think>abc</think>xyz",
+            previous_token_ids=[],
+            current_token_ids=[
+                self.parser.think_start_token_id,
+                200,
+                201,
+                202,
+                self.parser.think_end_token_id,
+                110,
+                120,
+                130,
+            ],
+            delta_token_ids=[
+                self.parser.think_start_token_id,
+                200,
+                201,
+                202,
+                self.parser.think_end_token_id,
+                110,
+                120,
+                130,
+            ],
+        )
+        self.assertIsNotNone(msg)
+        self.assertEqual(msg.reasoning_content, "abc")
+        self.assertEqual(msg.content, "xyz")
+
+    def test_streaming_reasoning_start_tag(self):
+        """Test streaming when the start tag arrives."""
+        msg = self.parser.extract_reasoning_content_streaming(
+            previous_text="",
+            current_text="<think>abc",
+            delta_text="<think>abc",
+            previous_token_ids=[],
+            current_token_ids=[self.parser.think_start_token_id, 200, 201, 202],
+            delta_token_ids=[self.parser.think_start_token_id, 200, 201, 202],
+        )
+        self.assertIsNotNone(msg)
+        self.assertEqual(msg.reasoning_content, "abc")
+        self.assertIsNone(msg.content)
+
+    def test_streaming_no_reasoning_tags(self):
+        """Test streaming without reasoning tags (thinking switch off)."""
+        msg = self.parser.extract_reasoning_content_streaming(
+            previous_text="",
+            current_text="direct",
+            delta_text="direct",
+            previous_token_ids=[],
+            current_token_ids=[200],
+            delta_token_ids=[200],
+            output_stage="CONTENT_STAGE",
+        )
+        self.assertIsNotNone(msg)
+        self.assertIsNone(msg.reasoning_content)
+        self.assertEqual(msg.content, "direct")
+
+    # ---- Stage detection ----
+    def test_detect_output_stage_reasoning(self):
+        """Test detecting the reasoning stage."""
+        prompt_token_ids = [self.parser.think_start_token_id]  # Contains the <think> start token
+        stage = self.parser.detect_output_stage(prompt_token_ids)
+        self.assertEqual(stage, "REASONING_STAGE")
+
+    def test_detect_output_stage_content(self):
+        """Test detecting the reply stage."""
+        prompt_token_ids = [
+            self.parser.think_start_token_id,
+            self.parser.think_end_token_id,
+        ]  # Contains both <think> and </think>
+        stage = self.parser.detect_output_stage(prompt_token_ids)
+        self.assertEqual(stage, "CONTENT_STAGE")
+
+    def test_detect_output_stage_no_tags(self):
+        """Test that the reply stage is the default when no reasoning tokens are present."""
+        prompt_token_ids = [200, 201, 202]  # No reasoning tokens
+        stage = self.parser.detect_output_stage(prompt_token_ids)
+        self.assertEqual(stage, "CONTENT_STAGE")
+
+    # ---- Edge cases ----
+    def test_batch_multiple_end_tags(self):
+        """Test multiple end tags (only the first is recognized)."""
+        text = "<think>abc</think>xyz</think>more"
+        reasoning, content = self.parser.extract_reasoning_content(text, self.request)
+        self.assertEqual(reasoning, "abc")
+        self.assertEqual(content, "xyz</think>more")
+
+    def test_is_reasoning_end(self):
+        """Test checking whether reasoning content has ended."""
+        input_ids = [200, 201, self.parser.think_end_token_id, 202]
+        result = self.parser.is_reasoning_end(input_ids)
+        self.assertTrue(result)
+
+        input_ids = [200, 201, 202]
+        result = self.parser.is_reasoning_end(input_ids)
+        self.assertFalse(result)
+
+    def test_extract_content_ids(self):
+        """Test extracting content token IDs."""
+        input_ids = [200, 201, self.parser.think_end_token_id, 202, 203]
+        result = self.parser.extract_content_ids(input_ids)
+        self.assertEqual(result, [202, 203])
+
+        input_ids = [200, 201, 202]
+        result = self.parser.extract_content_ids(input_ids)
+        self.assertEqual(result, [200, 201, 202])
+
+
 if __name__ == "__main__":
     unittest.main()