Skip to content

Commit

Permalink
支持阿里云ticket、支持原生多轮对话、优化绘图
Browse files Browse the repository at this point in the history
  • Loading branch information
Vinlic committed May 10, 2024
1 parent b3bfb22 commit 574e398
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 63 deletions.
160 changes: 105 additions & 55 deletions src/api/controllers/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ const FILE_MAX_SIZE = 100 * 1024 * 1024;
*
* 在对话流传输完毕后移除会话,避免创建的会话出现在用户的对话列表中
*
* @param ticket login_tongyi_ticket值
* @param ticket login_tongyi_ticket或login_aliyunid_ticket
*/
async function removeConversation(convId: string, ticket: string) {
const result = await axios.post(
Expand All @@ -73,13 +73,15 @@ async function removeConversation(convId: string, ticket: string) {
*
* @param model 模型名称
* @param messages 参考gpt系列消息格式,多轮对话请完整提供上下文
* @param ticket login_tongyi_ticket值
* @param ticket login_tongyi_ticket或login_aliyunid_ticket
* @param refConvId 引用的会话ID
* @param retryCount 重试次数
*/
async function createCompletion(
model = MODEL_NAME,
messages: any[],
ticket: string,
refConvId = '',
retryCount = 0
) {
let session: http2.ClientHttp2Session;
Expand All @@ -94,6 +96,10 @@ async function createCompletion(
)
: [];

// 如果引用对话ID不正确则重置引用
if (!/[0-9a-z]{32}/.test(refConvId))
refConvId = '';

// 请求流
const session: http2.ClientHttp2Session = await new Promise(
(resolve, reject) => {
Expand All @@ -102,6 +108,7 @@ async function createCompletion(
session.on("error", reject);
}
);
const [sessionId, parentMsgId = ''] = refConvId.split('-');
const req = session.request({
":method": "POST",
":path": "/dialog/conversation",
Expand All @@ -118,10 +125,13 @@ async function createCompletion(
action: "next",
userAction: "chat",
requestId: util.uuid(false),
sessionId: "",
sessionId,
sessionType: "text_chat",
parentMsgId: "",
contents: messagesPrepare(messages, refs),
parentMsgId,
params: {
"fileUploadBatchId": util.uuid()
},
contents: messagesPrepare(messages, refs, !!refConvId),
})
);
req.setEncoding("utf8");
Expand All @@ -144,7 +154,7 @@ async function createCompletion(
logger.warn(`Try again after ${RETRY_DELAY / 1000}s...`);
return (async () => {
await new Promise((resolve) => setTimeout(resolve, RETRY_DELAY));
return createCompletion(model, messages, ticket, retryCount + 1);
return createCompletion(model, messages, ticket, refConvId, retryCount + 1);
})();
}
throw err;
Expand All @@ -156,14 +166,15 @@ async function createCompletion(
*
* @param model 模型名称
* @param messages 参考gpt系列消息格式,多轮对话请完整提供上下文
* @param ticket login_tongyi_ticket值
* @param useSearch 是否开启联网搜索
* @param ticket login_tongyi_ticket或login_aliyunid_ticket
* @param refConvId 引用的会话ID
* @param retryCount 重试次数
*/
async function createCompletionStream(
model = MODEL_NAME,
messages: any[],
ticket: string,
refConvId = '',
retryCount = 0
) {
let session: http2.ClientHttp2Session;
Expand All @@ -178,12 +189,17 @@ async function createCompletionStream(
)
: [];

// 如果引用对话ID不正确则重置引用
if (!/[0-9a-z]{32}/.test(refConvId))
refConvId = ''

// 请求流
session = await new Promise((resolve, reject) => {
const session = http2.connect("https://qianwen.biz.aliyun.com");
session.on("connect", () => resolve(session));
session.on("error", reject);
});
const [sessionId, parentMsgId = ''] = refConvId.split('-');
const req = session.request({
":method": "POST",
":path": "/dialog/conversation",
Expand All @@ -200,10 +216,13 @@ async function createCompletionStream(
action: "next",
userAction: "chat",
requestId: util.uuid(false),
sessionId: "",
sessionId,
sessionType: "text_chat",
parentMsgId: "",
contents: messagesPrepare(messages, refs),
parentMsgId,
params: {
"fileUploadBatchId": util.uuid()
},
contents: messagesPrepare(messages, refs, !!refConvId),
})
);
req.setEncoding("utf8");
Expand All @@ -225,7 +244,7 @@ async function createCompletionStream(
logger.warn(`Try again after ${RETRY_DELAY / 1000}s...`);
return (async () => {
await new Promise((resolve) => setTimeout(resolve, RETRY_DELAY));
return createCompletionStream(model, messages, ticket, retryCount + 1);
return createCompletionStream(model, messages, ticket, refConvId, retryCount + 1);
})();
}
throw err;
Expand All @@ -238,6 +257,7 @@ async function generateImages(
ticket: string,
retryCount = 0
) {
let session: http2.ClientHttp2Session;
return (async () => {
const messages = [
{
Expand All @@ -250,48 +270,56 @@ async function generateImages(
content:
"好的,我将确保处理用户一切的绘图请求并且不会随意更改绘图请求。",
},
{ role: "user", content: prompt },
{ role: "user", content: prompt.indexOf('画') == -1 ? `请画:${prompt}` : prompt },
];
// 创建会话并获得流
const result = await axios.post(
"https://qianwen.biz.aliyun.com/dialog/conversation",
{
// 请求流
const session: http2.ClientHttp2Session = await new Promise(
(resolve, reject) => {
const session = http2.connect("https://qianwen.biz.aliyun.com");
session.on("connect", () => resolve(session));
session.on("error", reject);
}
);
const req = session.request({
":method": "POST",
":path": "/dialog/conversation",
"Content-Type": "application/json",
Cookie: generateCookie(ticket),
...FAKE_HEADERS,
Accept: "text/event-stream",
});
req.setTimeout(120000);
req.write(
JSON.stringify({
mode: "chat",
model: "",
action: "next",
mode: "chat",
userAction: "chat",
requestId: util.uuid(false),
sessionId: "",
sessionType: "text_chat",
parentMsgId: "",
contents: messagesPrepare(messages),
},
{
headers: {
Cookie: generateCookie(ticket),
...FAKE_HEADERS,
Accept: "text/event-stream",
params: {
"fileUploadBatchId": util.uuid()
},
timeout: 120000,
validateStatus: () => true,
responseType: "stream",
}
contents: messagesPrepare(messages),
})
);
req.setEncoding("utf8");
const streamStartTime = util.timestamp();
// 接收流为输出文本
const { convId, imageUrls } = await receiveImages(result.data);
const { convId, imageUrls } = await receiveImages(req);
session.close();
logger.success(
`Stream has completed transfer ${util.timestamp() - streamStartTime}ms`
);

// 异步移除会话,如果消息不合规,此操作可能会抛出数据库错误异常,请忽略
removeConversation(convId, ticket).catch((err) => console.error(err));

if (imageUrls.length == 0)
throw new APIException(EX.API_IMAGE_GENERATION_FAILED);

return imageUrls;
})().catch((err) => {
session && session.close();
if (retryCount < MAX_RETRY_COUNT) {
logger.error(`Stream response error: ${err.message}`);
logger.warn(`Try again after ${RETRY_DELAY / 1000}s...`);
Expand Down Expand Up @@ -349,25 +377,44 @@ function extractRefFileUrls(messages: any[]) {
* user:新消息
*
* @param messages 参考gpt系列消息格式,多轮对话请完整提供上下文
* @param refs 参考文件列表
* @param isRefConv 是否为引用会话
*/
function messagesPrepare(messages: any[], refs: any[] = []) {
const content = messages.reduce((content, message) => {
if (_.isArray(message.content)) {
return message.content.reduce((_content, v) => {
if (!_.isObject(v) || v["type"] != "text") return _content;
return _content + `<|im_start|>${message.role || "user"}\n${v["text"] || ""}<|im_end|>\n`;
}, content);
}
return (content += `<|im_start|>${message.role || "user"}\n${
message.content
}<|im_end|>\n`);
}, "").replace(/\!\[.*\]\(.+\)/g, "");
logger.info("\n对话合并:\n" + content);
function messagesPrepare(messages: any[], refs: any[] = [], isRefConv = false) {
let content;
if (isRefConv || messages.length < 2) {
content = messages.reduce((content, message) => {
if (_.isArray(message.content)) {
return (
message.content.reduce((_content, v) => {
if (!_.isObject(v) || v["type"] != "text") return _content;
return _content + (v["text"] || "") + "\n";
}, content)
);
}
return content + `${message.content}\n`;
}, "");
logger.info("\n透传内容:\n" + content);
}
else {
content = messages.reduce((content, message) => {
if (_.isArray(message.content)) {
return message.content.reduce((_content, v) => {
if (!_.isObject(v) || v["type"] != "text") return _content;
return _content + `<|im_start|>${message.role || "user"}\n${v["text"] || ""}<|im_end|>\n`;
}, content);
}
return (content += `<|im_start|>${message.role || "user"}\n${
message.content
}<|im_end|>\n`);
}, "").replace(/\!\[.*\]\(.+\)/g, "");
logger.info("\n对话合并:\n" + content);
}
return [
{
role: "user",
contentType: "text",
content,
contentType: "text",
role: "user",
},
...refs
];
Expand Down Expand Up @@ -418,7 +465,8 @@ async function receiveStream(stream: any): Promise<any> {
const result = _.attempt(() => JSON.parse(event.data));
if (_.isError(result))
throw new Error(`Stream response invalid: ${event.data}`);
if (!data.id && result.sessionId) data.id = result.sessionId;
if (!data.id && result.sessionId && result.msgId)
data.id = `${result.sessionId}-${result.msgId}`;
const text = (result.contents || []).reduce((str, part) => {
const { contentType, role, content } = part;
if (contentType != "text" && contentType != "text2image") return str;
Expand Down Expand Up @@ -532,7 +580,7 @@ function createTransStream(stream: any, endCallback?: Function) {
if (chunk && result.contentType == "text") {
content += chunk;
const data = `data: ${JSON.stringify({
id: result.sessionId,
id: `${result.sessionId}-${result.msgId}`,
model: MODEL_NAME,
object: "chat.completion.chunk",
choices: [
Expand All @@ -549,7 +597,7 @@ function createTransStream(stream: any, endCallback?: Function) {
if (result.errorCode)
delta.content += `服务暂时不可用,第三方响应错误:${result.errorCode}`;
const data = `data: ${JSON.stringify({
id: result.sessionId,
id: `${result.sessionId}-${result.msgId}`,
model: MODEL_NAME,
object: "chat.completion.chunk",
choices: [
Expand Down Expand Up @@ -644,13 +692,14 @@ async function receiveImages(
stream.on("data", (buffer) => parser.feed(buffer.toString()));
stream.once("error", (err) => reject(err));
stream.once("close", () => resolve({ convId, imageUrls }));
stream.end();
});
}

/**
* 获取上传参数
*
* @param ticket login_tongyi_ticket值
* @param ticket login_tongyi_ticket或login_aliyunid_ticket
*/
async function acquireUploadParams(ticket: string) {
const result = await axios.post(
Expand Down Expand Up @@ -700,7 +749,7 @@ async function checkFileUrl(fileUrl: string) {
* 上传文件
*
* @param fileUrl 文件URL
* @param ticket login_tongyi_ticket值
* @param ticket login_tongyi_ticket或login_aliyunid_ticket
*/
async function uploadFile(fileUrl: string, ticket: string) {
// 预检查远程文件URL可用性
Expand Down Expand Up @@ -870,11 +919,12 @@ function tokenSplit(authorization: string) {
/**
* 生成Cookies
*
* @param ticket login_tongyi_ticket值
* @param ticket login_tongyi_ticket或login_aliyunid_ticket
*/
function generateCookie(ticket: string) {
return [
`login_tongyi_ticket=${ticket}`,
`${ticket.length > 100 ? 'login_aliyunid_ticket' : 'login_tongyi_ticket'}=${ticket}`,
'aliyun_choice=intl',
"_samesite_flag_=true",
`t=${util.uuid(false)}`,
"channel=oug71n2fX3Jd5ualEfKACRvnsceUtpjUC5jHBpfWnSOXKhkvBNuSO8bG3v4HHjCgB722h7LqbHkB6sAxf3OvgA%3D%3D",
Expand Down
17 changes: 9 additions & 8 deletions src/api/routes/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,27 @@ import _ from "lodash";
import Request from "@/lib/request/Request.ts";
import Response from "@/lib/response/Response.ts";
import chat from "@/api/controllers/chat.ts";
import logger from "@/lib/logger.ts";

export default {
prefix: "/v1/chat",

post: {
"/completions": async (request: Request) => {
request
.validate('body.conversation_id', v => _.isUndefined(v) || _.isString(v))
.validate("body.messages", _.isArray)
.validate("headers.authorization", _.isString);
// refresh_token切分
// ticket切分
const tokens = chat.tokenSplit(request.headers.authorization);
// 随机挑选一个refresh_token
// 随机挑选一个ticket
const token = _.sample(tokens);
const model = request.body.model;
const messages = request.body.messages;
if (request.body.stream) {
const { model, conversation_id: convId, messages, stream } = request.body;
if (stream) {
const stream = await chat.createCompletionStream(
model,
messages,
token
token,
convId
);
return new Response(stream, {
type: "text/event-stream",
Expand All @@ -32,7 +32,8 @@ export default {
return await chat.createCompletion(
model,
messages,
token
token,
convId
);
},
},
Expand Down

0 comments on commit 574e398

Please sign in to comment.