Support native multi-turn conversations
Vinlic committed Apr 28, 2024
1 parent a695921 commit 295e69e
Showing 3 changed files with 57 additions and 28 deletions.
4 changes: 4 additions & 0 deletions README.md
@@ -278,6 +278,9 @@ Authorization: Bearer [refresh_token]
{
    // If you are using an agent, fill in the agent ID here; otherwise any value works
    "model": "glm4",
+   // Multi-turn conversation is currently implemented by merging messages, which may degrade capability in some scenarios and caps tokens at 4096
+   // For a native multi-turn experience, pass the id returned by the first turn to continue the context
+   // "conversation_id": "65f6c28546bae1f0fbb532de",
    "messages": [
        {
            "role": "user",
@@ -292,6 +295,7 @@ Authorization: Bearer [refresh_token]
Response data:
```json
{
+   // conversation_id: you can pass it into the next turn to continue the context
    "id": "65f6c28546bae1f0fbb532de",
    "model": "glm4",
    "object": "chat.completion",
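With this field, a client keeps context in either mode: omit `conversation_id` on the first turn, then reuse the `id` returned in the response. Below is a minimal sketch of a two-turn exchange; the base URL, port, and `/v1/chat/completions` mount point are assumptions (only `/completions` appears in the route file), and `REFRESH_TOKEN` must hold a valid refresh_token.

```ts
// Minimal two-turn sketch (assumed base URL and mount point).
const BASE_URL = "http://localhost:8000/v1/chat/completions";
const REFRESH_TOKEN = process.env.REFRESH_TOKEN!;

async function ask(content: string, conversationId?: string) {
  const res = await fetch(BASE_URL, {
    method: "POST",
    headers: {
      "Content-Type": "application/json",
      Authorization: `Bearer ${REFRESH_TOKEN}`,
    },
    body: JSON.stringify({
      model: "glm4",
      // First turn: omit conversation_id. Later turns: pass the id from the previous response.
      ...(conversationId ? { conversation_id: conversationId } : {}),
      messages: [{ role: "user", content }],
    }),
  });
  return res.json();
}

(async () => {
  const first = await ask("Hello!");
  // The response id doubles as the conversation_id for the next turn.
  const second = await ask("What did I just say?", first.id);
  console.log(second.choices?.[0]?.message?.content);
})();
```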
70 changes: 47 additions & 23 deletions src/api/controllers/chat.ts
@@ -170,6 +170,7 @@ async function createCompletion(
  messages: any[],
  refreshToken: string,
  assistantId = DEFAULT_ASSISTANT_ID,
+ refConvId = '',
  retryCount = 0
) {
return (async () => {
@@ -184,13 +185,14 @@
      : [];

    // request the stream
+   console.log(refConvId)
    const token = await acquireToken(refreshToken);
    const result = await axios.post(
      "https://chatglm.cn/chatglm/backend-api/assistant/stream",
      {
        assistant_id: assistantId,
-       conversation_id: "",
-       messages: messagesPrepare(messages, refs),
+       conversation_id: refConvId,
+       messages: messagesPrepare(messages, refs, !!refConvId),
        meta_data: {
          channel: "",
          draft_id: "",
@@ -232,7 +234,7 @@

    // remove the conversation asynchronously
    removeConversation(answer.id, refreshToken, assistantId).catch((err) =>
-     console.error(err)
+     !refConvId && console.error(err)
    );

    return answer;
@@ -246,6 +248,7 @@
        messages,
        refreshToken,
        assistantId,
+       refConvId,
        retryCount + 1
      );
    })();
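The retry branch above simply calls createCompletion again with the same arguments, now including refConvId, plus retryCount + 1. The same pattern in isolation, as a hedged sketch (the constants are placeholder values, not taken from this diff):

```ts
// Standalone sketch of the retry pattern used above: on failure, re-invoke the
// same call with identical arguments and an incremented retry counter.
const MAX_RETRY_COUNT = 3;
const RETRY_DELAY = 5000;

async function withRetry<T>(
  attempt: (retryCount: number) => Promise<T>,
  retryCount = 0
): Promise<T> {
  try {
    return await attempt(retryCount);
  } catch (err) {
    if (retryCount >= MAX_RETRY_COUNT) throw err;
    // Wait, then retry; everything else (messages, refreshToken, assistantId,
    // refConvId) is captured by the closure, so it is preserved across attempts.
    await new Promise((resolve) => setTimeout(resolve, RETRY_DELAY));
    return withRetry(attempt, retryCount + 1);
  }
}

// Usage sketch:
// withRetry((n) => createCompletion(messages, refreshToken, assistantId, refConvId, n));
```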
@@ -266,6 +269,7 @@ async function createCompletionStream(
  messages: any[],
  refreshToken: string,
  assistantId = DEFAULT_ASSISTANT_ID,
+ refConvId = '',
  retryCount = 0
) {
return (async () => {
@@ -285,8 +289,8 @@
      `https://chatglm.cn/chatglm/backend-api/assistant/stream`,
      {
        assistant_id: assistantId,
-       conversation_id: "",
-       messages: messagesPrepare(messages, refs),
+       conversation_id: refConvId,
+       messages: messagesPrepare(messages, refs, !!refConvId),
        meta_data: {
          channel: "",
          draft_id: "",
@@ -349,7 +353,7 @@
      );
      // remove the conversation asynchronously after streaming ends
      removeConversation(convId, refreshToken, assistantId).catch((err) =>
-       console.error(err)
+       !refConvId && console.error(err)
      );
    });
})().catch((err) => {
@@ -362,6 +366,7 @@
        messages,
        refreshToken,
        assistantId,
+       refConvId,
        retryCount + 1
      );
    })();
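On the streaming path the route responds with text/event-stream. A client can read the SSE lines, accumulate the delta text, and keep the chunk id as the conversation_id for the next turn. A minimal consumer sketch, assuming OpenAI-style `data: {...}` lines terminated by `data: [DONE]` (the exact chunk shape is not shown in this diff):

```ts
// Minimal SSE consumer sketch. The "data: {json}" / "data: [DONE]" framing and
// the chunk shape are assumptions, not taken from this diff.
async function readStream(res: Response): Promise<{ id: string; text: string }> {
  const reader = res.body!.getReader();
  const decoder = new TextDecoder();
  let buffer = "";
  let id = "";
  let text = "";
  while (true) {
    const { done, value } = await reader.read();
    if (done) break;
    buffer += decoder.decode(value, { stream: true });
    const lines = buffer.split("\n");
    buffer = lines.pop() ?? ""; // keep the last partial line in the buffer
    for (const line of lines) {
      if (!line.startsWith("data:")) continue;
      const payload = line.slice(5).trim();
      if (payload === "[DONE]") continue;
      const chunk = JSON.parse(payload);
      id ||= chunk.id; // reuse this id as conversation_id on the next turn
      text += chunk.choices?.[0]?.delta?.content ?? "";
    }
  }
  return { id, text };
}
```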
@@ -488,8 +493,10 @@ function extractRefFileUrls(messages: any[]) {
   * Since the upstream API only takes the first message, multiple messages are merged into one here to emulate multi-turn conversation
   *
   * @param messages gpt-style message list; for multi-turn conversation please provide the full context
+  * @param refs list of referenced files
+  * @param isRefConv whether this continues a referenced conversation
   */
- function messagesPrepare(messages: any[], refs: any[]) {
+ function messagesPrepare(messages: any[], refs: any[], isRefConv = false) {
    // Check whether the latest message contains "type": "image_url" or "type": "file"; if so, inject a message
    let latestMessage = messages[messages.length - 1];
    let hasFileOrImage =
@@ -514,36 +521,53 @@ function messagesPrepare(messages: any[], refs: any[]) {
    // logger.info("Inject a system prompt to boost attention to the tail message");
  }

- const content = (
-   messages.reduce((content, message) => {
-     const role = message.role
-       .replace("system", "<|sytstem|>")
-       .replace("assistant", "<|assistant|>")
-       .replace("user", "<|user|>");
+ let content;
+ if(isRefConv || messages.length < 2) {
+   content = messages.reduce((content, message) => {
      if (_.isArray(message.content)) {
        return (
          message.content.reduce((_content, v) => {
            if (!_.isObject(v) || v["type"] != "text") return _content;
-           return _content + (`${role}\n` + v["text"] || "") + "\n";
+           return _content + (v["text"] || "") + "\n";
          }, content)
        );
      }
-     return (content += `${role}\n${message.content}\n`);
-   }, "") + "<|assistant|>\n"
- )
-   // Strip markdown image URLs to avoid hallucinations
-   .replace(/\!\[.+\]\(.+\)/g, "")
-   // Strip temporary paths to avoid hallucinations in a new conversation
-   .replace(/\/mnt\/data\/.+/g, "");
+     return content + `${message.content}\n`;
+   }, "");
+   logger.info("\nPass-through content:\n" + content);
+ }
+ else {
+   content = (
+     messages.reduce((content, message) => {
+       const role = message.role
+         .replace("system", "<|sytstem|>")
+         .replace("assistant", "<|assistant|>")
+         .replace("user", "<|user|>");
+       if (_.isArray(message.content)) {
+         return (
+           message.content.reduce((_content, v) => {
+             if (!_.isObject(v) || v["type"] != "text") return _content;
+             return _content + (`${role}\n` + v["text"] || "") + "\n";
+           }, content)
+         );
+       }
+       return (content += `${role}\n${message.content}\n`);
+     }, "") + "<|assistant|>\n"
+   )
+     // Strip markdown image URLs to avoid hallucinations
+     .replace(/\!\[.+\]\(.+\)/g, "")
+     // Strip temporary paths to avoid hallucinations in a new conversation
+     .replace(/\/mnt\/data\/.+/g, "");
+   logger.info("\nMerged conversation:\n" + content);
+ }

  const fileRefs = refs.filter((ref) => !ref.width && !ref.height);
  const imageRefs = refs
    .filter((ref) => ref.width || ref.height)
    .map((ref) => {
      ref.image_url = ref.file_url;
      return ref;
    });
- content
- logger.info("\nMerged conversation:\n" + content);
  return [
    {
      role: "user",
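The net effect of the two branches above: with a conversation_id (or a single message) the text is passed through as-is and the upstream conversation carries the context natively; otherwise the whole history is flattened into one transcript tagged with role markers and closed with `<|assistant|>\n`. A condensed illustration of the two shapes (simplified: no file/image refs, no markdown stripping; helper names are made up):

```ts
// Condensed sketch of the two preparation modes in messagesPrepare.
type Msg = { role: "system" | "user" | "assistant"; content: string };

// Native multi-turn: a conversation_id is supplied, so only the raw text is
// sent and the upstream conversation supplies the earlier context.
function passThrough(messages: Msg[]): string {
  return messages.map((m) => m.content + "\n").join("");
}

// Merged multi-turn: no conversation_id, so the whole history is flattened into
// one transcript tagged with role markers and closed with "<|assistant|>\n".
function mergeTranscript(messages: Msg[]): string {
  return (
    messages.map((m) => `<|${m.role}|>\n${m.content}\n`).join("") +
    "<|assistant|>\n"
  );
}

// Example:
// mergeTranscript([{ role: "user", content: "hi" }, { role: "assistant", content: "hello" }, { role: "user", content: "who am I?" }])
// => "<|user|>\nhi\n<|assistant|>\nhello\n<|user|>\nwho am I?\n<|assistant|>\n"
```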
11 changes: 6 additions & 5 deletions src/api/routes/chat.ts
@@ -13,22 +13,23 @@ export default {

  '/completions': async (request: Request) => {
    request
+     .validate('body.conversation_id', v => _.isUndefined(v) || _.isString(v))
      .validate('body.messages', _.isArray)
      .validate('headers.authorization', _.isString)
    // split the refresh_token
    const tokens = chat.tokenSplit(request.headers.authorization);
    // randomly pick one refresh_token
    const token = _.sample(tokens);
-   const messages = request.body.messages;
-   const assistantId = /^[a-z0-9]{24,}$/.test(request.body.model) ? request.body.model : undefined
-   if (request.body.stream) {
-     const stream = await chat.createCompletionStream(request.body.messages, token, assistantId);
+   const { model, conversation_id: convId, messages, stream } = request.body;
+   const assistantId = /^[a-z0-9]{24,}$/.test(model) ? model : undefined
+   if (stream) {
+     const stream = await chat.createCompletionStream(messages, token, assistantId, convId);
      return new Response(stream, {
        type: "text/event-stream"
      });
    }
    else
-     return await chat.createCompletion(messages, token, assistantId);
+     return await chat.createCompletion(messages, token, assistantId, convId);
  }

}
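As the route shows, the model field doubles as an agent selector: a value of 24 or more lowercase alphanumeric characters is used as the assistant ID, and anything else (such as "glm4") falls back to the default assistant. A small sketch of that dispatch, with a made-up agent ID:

```ts
// Sketch of the model-vs-agent dispatch in the route above.
function resolveAssistantId(model: string): string | undefined {
  return /^[a-z0-9]{24,}$/.test(model) ? model : undefined;
}

console.log(resolveAssistantId("glm4"));                     // undefined -> DEFAULT_ASSISTANT_ID is used
console.log(resolveAssistantId("0123456789abcdef01234567")); // hypothetical agent ID, used as assistant_id
```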
