Skip to content

Commit

Permalink
支持原生多轮对话
Browse files Browse the repository at this point in the history
  • Loading branch information
Vinlic committed Apr 28, 2024
1 parent 2aa6465 commit 7cc6033
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 28 deletions.
74 changes: 51 additions & 23 deletions src/api/controllers/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -237,9 +237,10 @@ async function promptSnippetSubmit(query: string, refreshToken: string) {
* @param messages 参考gpt系列消息格式,多轮对话请完整提供上下文
* @param refreshToken 用于刷新access_token的refresh_token
* @param useSearch 是否开启联网搜索
* @param refConvId 引用会话ID
* @param retryCount 重试次数
*/
async function createCompletion(model = MODEL_NAME, messages: any[], refreshToken: string, useSearch = true, retryCount = 0) {
async function createCompletion(model = MODEL_NAME, messages: any[], refreshToken: string, useSearch = true, refConvId?: string, retryCount = 0) {
return (async () => {
logger.info(messages);

Expand All @@ -252,14 +253,19 @@ async function createCompletion(model = MODEL_NAME, messages: any[], refreshToke
.catch(err => logger.error(err));

// 创建会话
const convId = await createConversation("未命名会话", refreshToken);
const convId = /[0-9a-zA-Z]{20}/.test(refConvId) ? refConvId : await createConversation("未命名会话", refreshToken);

// 请求流
const {
accessToken,
userId
} = await acquireToken(refreshToken);
const sendMessages = messagesPrepare(messages);
const sendMessages = messagesPrepare(messages, !!refConvId);
console.log(convId, {
messages: sendMessages,
refs,
use_search: useSearch
});
const result = await axios.post(`https://kimi.moonshot.cn/api/chat/${convId}/completion/stream`, {
messages: sendMessages,
refs,
Expand All @@ -268,6 +274,7 @@ async function createCompletion(model = MODEL_NAME, messages: any[], refreshToke
headers: {
Authorization: `Bearer ${accessToken}`,
Referer: `https://kimi.moonshot.cn/chat/${convId}`,
'Priority': 'u=1, i',
'X-Traffic-Id': userId,
...FAKE_HEADERS
},
Expand All @@ -283,7 +290,8 @@ async function createCompletion(model = MODEL_NAME, messages: any[], refreshToke
logger.success(`Stream has completed transfer ${util.timestamp() - streamStartTime}ms`);

// 异步移除会话,如果消息不合规,此操作可能会抛出数据库错误异常,请忽略
removeConversation(convId, refreshToken)
// 如果引用会话将不会清除,因为我们不知道什么时候你会结束会话
!refConvId && removeConversation(convId, refreshToken)
.catch(err => console.error(err));
promptSnippetSubmit(sendMessages[0].content, refreshToken)
.catch(err => console.error(err));
Expand All @@ -296,7 +304,7 @@ async function createCompletion(model = MODEL_NAME, messages: any[], refreshToke
logger.warn(`Try again after ${RETRY_DELAY / 1000}s...`);
return (async () => {
await new Promise(resolve => setTimeout(resolve, RETRY_DELAY));
return createCompletion(model, messages, refreshToken, useSearch, retryCount + 1);
return createCompletion(model, messages, refreshToken, useSearch, refConvId, retryCount + 1);
})();
}
throw err;
Expand All @@ -310,9 +318,10 @@ async function createCompletion(model = MODEL_NAME, messages: any[], refreshToke
* @param messages 参考gpt系列消息格式,多轮对话请完整提供上下文
* @param refreshToken 用于刷新access_token的refresh_token
* @param useSearch 是否开启联网搜索
* @param refConvId 引用会话ID
* @param retryCount 重试次数
*/
async function createCompletionStream(model = MODEL_NAME, messages: any[], refreshToken: string, useSearch = true, retryCount = 0) {
async function createCompletionStream(model = MODEL_NAME, messages: any[], refreshToken: string, useSearch = true, refConvId?: string, retryCount = 0) {
return (async () => {
logger.info(messages);

Expand All @@ -325,14 +334,14 @@ async function createCompletionStream(model = MODEL_NAME, messages: any[], refre
.catch(err => logger.error(err));

// 创建会话
const convId = await createConversation("未命名会话", refreshToken);
const convId = /[0-9a-zA-Z]{20}/.test(refConvId) ? refConvId : await createConversation("未命名会话", refreshToken);

// 请求流
const {
accessToken,
userId
} = await acquireToken(refreshToken);
const sendMessages = messagesPrepare(messages);
const sendMessages = messagesPrepare(messages, !!refConvId);
const result = await axios.post(`https://kimi.moonshot.cn/api/chat/${convId}/completion/stream`, {
messages: sendMessages,
refs,
Expand All @@ -343,6 +352,7 @@ async function createCompletionStream(model = MODEL_NAME, messages: any[], refre
headers: {
Authorization: `Bearer ${accessToken}`,
Referer: `https://kimi.moonshot.cn/chat/${convId}`,
'Priority': 'u=1, i',
'X-Traffic-Id': userId,
...FAKE_HEADERS
},
Expand All @@ -354,7 +364,8 @@ async function createCompletionStream(model = MODEL_NAME, messages: any[], refre
return createTransStream(model, convId, result.data, () => {
logger.success(`Stream has completed transfer ${util.timestamp() - streamStartTime}ms`);
// 流传输结束后异步移除会话,如果消息不合规,此操作可能会抛出数据库错误异常,请忽略
removeConversation(convId, refreshToken)
// 如果引用会话将不会清除,因为我们不知道什么时候你会结束会话
!refConvId && removeConversation(convId, refreshToken)
.catch(err => console.error(err));
promptSnippetSubmit(sendMessages[0].content, refreshToken)
.catch(err => console.error(err));
Expand All @@ -366,7 +377,7 @@ async function createCompletionStream(model = MODEL_NAME, messages: any[], refre
logger.warn(`Try again after ${RETRY_DELAY / 1000}s...`);
return (async () => {
await new Promise(resolve => setTimeout(resolve, RETRY_DELAY));
return createCompletionStream(model, messages, refreshToken, useSearch, retryCount + 1);
return createCompletionStream(model, messages, refreshToken, useSearch, refConvId, retryCount + 1);
})();
}
throw err;
Expand Down Expand Up @@ -447,8 +458,9 @@ function extractRefFileUrls(messages: any[]) {
* user:新消息
*
* @param messages 参考gpt系列消息格式,多轮对话请完整提供上下文
* @param isRefConv 是否为引用会话
*/
function messagesPrepare(messages: any[]) {
function messagesPrepare(messages: any[], isRefConv = false) {
// 注入消息提升注意力
let latestMessage = messages[messages.length - 1];
let hasFileOrImage = Array.isArray(latestMessage.content)
Expand All @@ -472,16 +484,32 @@ function messagesPrepare(messages: any[]) {
}
}

const content = messages.reduce((content, message) => {
if (Array.isArray(message.content)) {
return message.content.reduce((_content, v) => {
if (!_.isObject(v) || v['type'] != 'text') return _content;
return _content + `${message.role || "user"}:${v["text"] || ""}\n`;
}, content);
}
return content += `${message.role || "user"}:${message.role == 'user' ? wrapUrlsToTags(message.content) : message.content}\n`;
}, '');
logger.info("\n对话合并:\n" + content);
let content;
if (isRefConv || messages.length < 2) {
content = messages.reduce((content, message) => {
if (_.isArray(message.content)) {
return message.content.reduce((_content, v) => {
if (!_.isObject(v) || v['type'] != 'text') return _content;
return _content + `${v["text"] || ""}\n`;
}, content);
}
return content += `${message.role == 'user' ? wrapUrlsToTags(message.content) : message.content}\n`;
}, '')
logger.info("\n透传内容:\n" + content);
}
else {
content = messages.reduce((content, message) => {
if (_.isArray(message.content)) {
return message.content.reduce((_content, v) => {
if (!_.isObject(v) || v['type'] != 'text') return _content;
return _content + `${message.role || "user"}:${v["text"] || ""}\n`;
}, content);
}
return content += `${message.role || "user"}:${message.role == 'user' ? wrapUrlsToTags(message.content) : message.content}\n`;
}, '')
logger.info("\n对话合并:\n" + content);
}

return [
{ role: 'user', content }
]
Expand Down Expand Up @@ -648,8 +676,8 @@ async function uploadFile(fileUrl: string, refreshToken: string) {
...FAKE_HEADERS
}
})
.then(() => resolve(true))
.catch(() => resolve(false));
.then(() => resolve(true))
.catch(() => resolve(false));
});
}

Expand Down
10 changes: 5 additions & 5 deletions src/api/routes/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,22 @@ export default {

'/completions': async (request: Request) => {
request
.validate('body.conversation_id', v => _.isUndefined(v) || _.isString(v))
.validate('body.messages', _.isArray)
.validate('headers.authorization', _.isString)
// refresh_token切分
const tokens = chat.tokenSplit(request.headers.authorization);
// 随机挑选一个refresh_token
const token = _.sample(tokens);
const model = request.body.model;
const messages = request.body.messages;
if (request.body.stream) {
const stream = await chat.createCompletionStream(model, messages, token, request.body.use_search);
const { model, conversation_id: convId, messages, stream, use_search } = request.body;
if (stream) {
const stream = await chat.createCompletionStream(model, messages, token, use_search, convId);
return new Response(stream, {
type: "text/event-stream"
});
}
else
return await chat.createCompletion(model, messages, token, request.body.use_search);
return await chat.createCompletion(model, messages, token, use_search, convId);
}

}
Expand Down

0 comments on commit 7cc6033

Please sign in to comment.