Skip to content
147 changes: 147 additions & 0 deletions tests/ce/server/test_base_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import json

import requests
from core import TEMPLATE, URL, build_request_payload, get_token_list, send_request


Expand Down Expand Up @@ -271,3 +272,149 @@ def test_bad_words_filtering1():
assert word in token_list, f"'{word}' 应出现在生成结果中"

print("test_bad_words_filtering1 正例验证通过")


def test_n_parameters():
"""
n参数测试 n=2
"""
# 1. 构建请求
data = {
"stream": False,
"messages": [
{"role": "user", "content": "牛顿是谁?"},
],
"max_tokens": 30,
"n": 2,
}
payload = build_request_payload(TEMPLATE, data)

# 2. 发送请求
resp = send_request(URL, payload).json()

# 3. 检查返回choices数量
choices = resp.get("choices", [])
assert len(choices) == 2, f"n参数为2,输出必须是2条数据,但实际返回 {len(choices)} 条"

# 4. 检查每条内容开头是否符合预期
expected_start = "牛顿是英国著名的物理学家"
for i, choice in enumerate(choices):
content = choice["message"]["content"]
print(f"Choice {i} 内容:", content)
assert content.startswith(expected_start), f"第{i}条输出内容开头不匹配"

print("test_n_parameters 验证通过")


def test_n_parameters1():
"""
n参数测试 n=3
"""
# 1. 构建请求
data = {
"stream": False,
"messages": [
{"role": "user", "content": "牛顿是谁?"},
],
"max_tokens": 30,
"n": 3,
}
payload = build_request_payload(TEMPLATE, data)

# 2. 发送请求
resp = send_request(URL, payload).json()

# 3. 检查返回choices数量
choices = resp.get("choices", [])
assert len(choices) == 3, f"n参数为3,输出必须是3条数据,但实际返回 {len(choices)} 条"

# 4. 检查每条内容开头是否符合预期
expected_start = "牛顿是英国著名的物理学家"
for i, choice in enumerate(choices):
content = choice["message"]["content"]
print(f"Choice {i} 内容:", content)
assert content.startswith(expected_start), f"第{i}条输出内容开头不匹配"

print("test_n_parameters 验证通过")


def test_n_parameters2():
"""
n参数测试 n=6
"""
# 1. 构建请求
data = {
"stream": False,
"messages": [
{"role": "user", "content": "牛顿是谁?"},
],
"max_tokens": 30,
"n": 6,
}
payload = build_request_payload(TEMPLATE, data)

# 2. 发送请求
resp = send_request(URL, payload).json()

# 3. 检查返回choices数量
choices = resp.get("choices", [])
assert len(choices) == 6, f"n参数为6,输出必须是6条数据,但实际返回 {len(choices)} 条"

# 4. 检查每条内容开头是否符合预期
expected_start = "牛顿是英国著名的物理学家"
for i, choice in enumerate(choices):
content = choice["message"]["content"]
print(f"Choice {i} 内容:", content)
assert content.startswith(expected_start), f"第{i}条输出内容开头不匹配"

print("test_n_parameters 验证通过")


def test_n_parameters_stream():
"""
n参数测试(流式输出 n=3)
"""
data = {
"stream": True,
"messages": [
{"role": "user", "content": "牛顿是谁?"},
],
"max_tokens": 30,
"n": 3,
}
payload = build_request_payload(TEMPLATE, data)

with requests.post(URL, json=payload, stream=True) as resp:
assert resp.status_code == 200, f"请求失败,状态码 {resp.status_code}"

# 初始化3个缓存
partial_contents = ["", "", ""]

for line in resp.iter_lines(decode_unicode=True):
if not line or not line.startswith("data: "):
continue
data_str = line[len("data: ") :]
if data_str.strip() == "[DONE]":
break

try:
data_json = json.loads(data_str)
except Exception as e:
print("解析异常:", e, line)
continue

choices = data_json.get("choices", [])
for choice in choices:
idx = choice.get("index", 0)
delta = choice.get("delta", {}).get("content", "")
if idx < len(partial_contents):
partial_contents[idx] += delta

# 检查流式聚合结果
assert len(partial_contents) == 3, "应产生3个流式输出"
expected_start = "牛顿是英国著名的物理学家"
for i, content in enumerate(partial_contents):
print(f"Choice {i} 最终内容:", content)
assert content.startswith(expected_start), f"第{i}条输出开头不匹配"

print("✅ test_n_parameters_stream 验证通过")
Loading