diff --git a/tests/ce/server/test_base_chat.py b/tests/ce/server/test_base_chat.py index cb160ca6254..57f02a25400 100644 --- a/tests/ce/server/test_base_chat.py +++ b/tests/ce/server/test_base_chat.py @@ -9,6 +9,7 @@ import json +import requests from core import TEMPLATE, URL, build_request_payload, get_token_list, send_request @@ -271,3 +272,149 @@ def test_bad_words_filtering1(): assert word in token_list, f"'{word}' 应出现在生成结果中" print("test_bad_words_filtering1 正例验证通过") + + +def test_n_parameters(): + """ + n参数测试 n=2 + """ + # 1. 构建请求 + data = { + "stream": False, + "messages": [ + {"role": "user", "content": "牛顿是谁?"}, + ], + "max_tokens": 30, + "n": 2, + } + payload = build_request_payload(TEMPLATE, data) + + # 2. 发送请求 + resp = send_request(URL, payload).json() + + # 3. 检查返回choices数量 + choices = resp.get("choices", []) + assert len(choices) == 2, f"n参数为2,输出必须是2条数据,但实际返回 {len(choices)} 条" + + # 4. 检查每条内容开头是否符合预期 + expected_start = "牛顿是英国著名的物理学家" + for i, choice in enumerate(choices): + content = choice["message"]["content"] + print(f"Choice {i} 内容:", content) + assert content.startswith(expected_start), f"第{i}条输出内容开头不匹配" + + print("test_n_parameters 验证通过") + + +def test_n_parameters1(): + """ + n参数测试 n=3 + """ + # 1. 构建请求 + data = { + "stream": False, + "messages": [ + {"role": "user", "content": "牛顿是谁?"}, + ], + "max_tokens": 30, + "n": 3, + } + payload = build_request_payload(TEMPLATE, data) + + # 2. 发送请求 + resp = send_request(URL, payload).json() + + # 3. 检查返回choices数量 + choices = resp.get("choices", []) + assert len(choices) == 3, f"n参数为3,输出必须是3条数据,但实际返回 {len(choices)} 条" + + # 4. 检查每条内容开头是否符合预期 + expected_start = "牛顿是英国著名的物理学家" + for i, choice in enumerate(choices): + content = choice["message"]["content"] + print(f"Choice {i} 内容:", content) + assert content.startswith(expected_start), f"第{i}条输出内容开头不匹配" + + print("test_n_parameters 验证通过") + + +def test_n_parameters2(): + """ + n参数测试 n=6 + """ + # 1. 构建请求 + data = { + "stream": False, + "messages": [ + {"role": "user", "content": "牛顿是谁?"}, + ], + "max_tokens": 30, + "n": 6, + } + payload = build_request_payload(TEMPLATE, data) + + # 2. 发送请求 + resp = send_request(URL, payload).json() + + # 3. 检查返回choices数量 + choices = resp.get("choices", []) + assert len(choices) == 6, f"n参数为6,输出必须是6条数据,但实际返回 {len(choices)} 条" + + # 4. 检查每条内容开头是否符合预期 + expected_start = "牛顿是英国著名的物理学家" + for i, choice in enumerate(choices): + content = choice["message"]["content"] + print(f"Choice {i} 内容:", content) + assert content.startswith(expected_start), f"第{i}条输出内容开头不匹配" + + print("test_n_parameters 验证通过") + + +def test_n_parameters_stream(): + """ + n参数测试(流式输出 n=3) + """ + data = { + "stream": True, + "messages": [ + {"role": "user", "content": "牛顿是谁?"}, + ], + "max_tokens": 30, + "n": 3, + } + payload = build_request_payload(TEMPLATE, data) + + with requests.post(URL, json=payload, stream=True) as resp: + assert resp.status_code == 200, f"请求失败,状态码 {resp.status_code}" + + # 初始化3个缓存 + partial_contents = ["", "", ""] + + for line in resp.iter_lines(decode_unicode=True): + if not line or not line.startswith("data: "): + continue + data_str = line[len("data: ") :] + if data_str.strip() == "[DONE]": + break + + try: + data_json = json.loads(data_str) + except Exception as e: + print("解析异常:", e, line) + continue + + choices = data_json.get("choices", []) + for choice in choices: + idx = choice.get("index", 0) + delta = choice.get("delta", {}).get("content", "") + if idx < len(partial_contents): + partial_contents[idx] += delta + + # 检查流式聚合结果 + assert len(partial_contents) == 3, "应产生3个流式输出" + expected_start = "牛顿是英国著名的物理学家" + for i, content in enumerate(partial_contents): + print(f"Choice {i} 最终内容:", content) + assert content.startswith(expected_start), f"第{i}条输出开头不匹配" + + print("✅ test_n_parameters_stream 验证通过")