From 06969eca073001a8896c390165d026773ef0faab Mon Sep 17 00:00:00 2001 From: Kaichen <276126767@qq.com> Date: Thu, 16 Mar 2023 23:08:22 +0100 Subject: [PATCH] feat: add stream output and memory mode --- package.json | 3 +- src/api/index.ts | 55 +++++++++++++- .../Function/components/SettingsModal.vue | 4 +- src/components/Input/components/RoleList.vue | 2 - src/components/Session/index.vue | 10 ++- src/stores/role.ts | 13 ++-- src/stores/session.ts | 23 ++++-- src/types/sql.d.ts | 2 +- src/utils/openai.ts | 71 +++++++++++-------- 9 files changed, 130 insertions(+), 53 deletions(-) diff --git a/package.json b/package.json index cad64e2..58b7333 100644 --- a/package.json +++ b/package.json @@ -18,6 +18,7 @@ ] }, "dependencies": { + "@microsoft/fetch-event-source": "^2.0.1", "@multiavatar/multiavatar": "^1.0.7", "@tauri-apps/api": "^1.2.0", "pinia": "^2.0.33", @@ -54,4 +55,4 @@ "vite": "^4.0.0", "vue-tsc": "^1.0.11" } -} +} \ No newline at end of file diff --git a/src/api/index.ts b/src/api/index.ts index 25978b8..6176443 100644 --- a/src/api/index.ts +++ b/src/api/index.ts @@ -2,7 +2,11 @@ import { fetch, Body } from '@tauri-apps/api/http' import { dialogErrorMessage } from '@/utils' import type { FetchOptions } from '@tauri-apps/api/http' import type { RecordData } from '@/types' -import { useSettingsStore } from '@/stores' +import { useSessionStore, useSettingsStore } from '@/stores' +import { + fetchEventSource, + type EventSourceMessage +} from '@microsoft/fetch-event-source' /** * 请求总入口 @@ -52,10 +56,57 @@ export const getOpenAIResultApi = async (messages: RecordData[]) => { body: Body.json({ model: 'gpt-3.5-turbo-0301', messages, - stream: true + temperature: 0.6, + stream: false }), headers: { Authorization: `Bearer ${apiKey || import.meta.env.VITE_OPEN_AI_API_KEY}` } }) } + +/** + * 获取 openai 对话消息(流) + * @param messages 消息列表 + */ +export const getOpenAIResultStream = async (messages: RecordData[]) => { + if (!messages.length) return + + const { apiKey } = useSettingsStore() + const { addSessionData } = useSessionStore() + const { streamReply } = storeToRefs(useSessionStore()) + streamReply.value = '' + + await fetchEventSource(import.meta.env.VITE_OPEN_AI_URL, { + method: 'POST', + body: JSON.stringify({ + model: 'gpt-3.5-turbo-0301', + messages, + temperature: 0.6, + stream: true + }), + headers: { + Authorization: `Bearer ${apiKey || import.meta.env.VITE_OPEN_AI_API_KEY}`, + 'Content-Type': 'application/json', + Accept: 'application/json' + }, + onmessage(msg: EventSourceMessage) { + if (msg.data !== '[DONE]') { + const { choices } = JSON.parse(msg.data) + + if (!choices[0].delta.content) return + streamReply.value += choices[0].delta.content + } + }, + onclose() { + const res: RecordData = { + role: 'assistant', + content: streamReply.value + } + addSessionData(false, '', res) + }, + onerror(err: any) { + throw new Error('流输出出错:', err) + } + }) +} diff --git a/src/components/Function/components/SettingsModal.vue b/src/components/Function/components/SettingsModal.vue index 42010ec..2ebba10 100644 --- a/src/components/Function/components/SettingsModal.vue +++ b/src/components/Function/components/SettingsModal.vue @@ -4,7 +4,7 @@ import ShortcutKey from './ShortcutKey.vue' defineProps<{ visible: boolean; setVisible: () => void }>() -const { apiKey, autoStart } = storeToRefs(useSettingsStore()) +const { apiKey, autoStart, isMemory } = storeToRefs(useSettingsStore()) diff --git a/src/stores/role.ts b/src/stores/role.ts index 245d227..513484d 100644 --- a/src/stores/role.ts +++ b/src/stores/role.ts @@ -23,8 +23,6 @@ export const useRoleStore = defineStore( defaultRole.value = [] roleList.value = result.map((item) => ({ ...item, isEdit: false })) - - changeCurrentRole() } const getFilterRoleList = (value: string) => { @@ -40,14 +38,13 @@ export const useRoleStore = defineStore( filterList.value.length = 0 } - const changeCurrentRole = () => { + const changeCurrentRole = async () => { const { currentSession } = useSessionStore() - const findRole = roleList.value.find( - (role) => role.id === currentSession?.role_id - ) - console.log('currentSession', currentSession, findRole) - currentRole.value = findRole ?? roleList.value[0] + const sql = `SELECT * FROM role WHERE id = ${currentSession?.role_id};` + const findRole = (await executeSQL(sql)) as RolePayload[] + + currentRole.value = findRole[0] ?? roleList.value[0] } const addRole = async (payload: RolePayload) => { diff --git a/src/stores/session.ts b/src/stores/session.ts index cb85c3a..ccf5563 100644 --- a/src/stores/session.ts +++ b/src/stores/session.ts @@ -1,6 +1,7 @@ import { executeSQL } from '@/sqls' import type { SessionPayload, SessionData, RecordData } from '@/types' import { useRoleStore } from './role' +import { useSettingsStore } from './settings' // TODO: 无记忆对话和有记忆对话 // 用来管理当前会话的状态 @@ -15,6 +16,8 @@ export const useSessionStore = defineStore( const sessionList = ref([]) // 请求发送状态 const isThinking = ref(false) + // 流式回复 + const streamReply = ref('') const getSessionList = async () => { const sql = @@ -57,26 +60,30 @@ export const useSessionStore = defineStore( return result.length > 0 } - const { changeCurrentRole, currentRole } = useRoleStore() - // TODO: 是否为记忆对话 // TODO: messageType从 types 中取到 const addSessionData = async ( isAsk: boolean, messageType: string, - data: RecordData[] + data: RecordData ) => { if (!currentSession.value) return // 检查会话是否已经存在 const isExist = await checkSessionExist() + const { currentRole } = useRoleStore() + if (!isExist) { - const sql = `INSERT INTO session (id, title, role_id) VALUES ('${currentSession.value.id}', '${data[1].content}', '${currentRole?.id}');` + const sql = `INSERT INTO session (id, title, role_id) VALUES ('${currentSession.value.id}', '${data.content}', '${currentRole?.id}');` executeSQL(sql) } - const sql = `INSERT INTO session_data (session_id, is_ask, messages) VALUES ( - '${currentSession.value.id}','${isAsk}', '${JSON.stringify(data)}');` + const { isMemory } = useSettingsStore() + + const sql = `INSERT INTO session_data (session_id, is_ask, is_memory, message) VALUES ( + '${currentSession.value.id}', ${isAsk}, ${isMemory}, '${JSON.stringify( + data + )}');` executeSQL(sql) getSessionData() @@ -88,6 +95,9 @@ export const useSessionStore = defineStore( else { currentSession.value = session } + + const { changeCurrentRole } = useRoleStore() + changeCurrentRole() getSessionData() } @@ -98,6 +108,7 @@ export const useSessionStore = defineStore( currentSession, sessionDataList, isThinking, + streamReply, sessionList, addSessionData, switchSession, diff --git a/src/types/sql.d.ts b/src/types/sql.d.ts index e9598a7..36f785d 100644 --- a/src/types/sql.d.ts +++ b/src/types/sql.d.ts @@ -21,7 +21,7 @@ export interface SessionData { is_ask: boolean is_memory: boolean message_type?: 'text' | 'image' | 'voice' - messages: RecordData[] + message: RecordData time?: string } diff --git a/src/utils/openai.ts b/src/utils/openai.ts index e6cc27e..e5ab031 100644 --- a/src/utils/openai.ts +++ b/src/utils/openai.ts @@ -1,39 +1,59 @@ -import { getOpenAIResultApi } from '@/api' +import { getOpenAIResultStream } from '@/api' import { executeSQL } from '@/sqls' -import { useSessionStore, useRoleStore } from '@/stores' -import { RecordData } from '@/types' +import { useSessionStore, useRoleStore, useSettingsStore } from '@/stores' +import { RecordData, SessionData } from '@/types' export const getAiMessage = async (value?: string) => { const { currentRole } = useRoleStore() if (!currentRole) return - let messages: RecordData[] + const messages: RecordData[] = [] + + const { sessionDataList, currentSession } = useSessionStore() + const { isMemory } = useSettingsStore() + + const lastQuestion = sessionDataList.filter((item) => item.is_ask).at(-1) + + // 记忆模式,或者是第一次对话,都要生成角色描述 + if (sessionDataList.length < 3 || isMemory) + messages.push({ + role: 'system', + content: currentRole.description + }) + + // 获取记忆(限制5条),往前推直到出现is_momery为false的 + // TODO 应该进行限流,防止出现过多的记忆,导致token超出 + const addMemory = async () => { + if (isMemory) { + const sql = `SELECT * FROM session_data WHERE session_id = '${currentSession?.id}' ORDER BY id DESC LIMIT 5;` + const memoryList = (await executeSQL(sql)) as SessionData[] + + let count = 0 + const arr = [] + while (count < memoryList.length) { + if (!memoryList[count].is_memory) break + arr.push(JSON.parse(memoryList[count++].message as any)) + } + messages.push(...arr.reverse()) + } + } // 再次生成上一次问题 if (!value) { - const { sessionDataList } = useSessionStore() - const lastQuestion = sessionDataList.filter((item) => item.is_ask).at(-1) - if (!lastQuestion) return // 为了保证统一,这之后的内容全部删掉 const deleteSql = `DELETE FROM session_data WHERE session_id = '${lastQuestion?.session_id}' AND id >= ${lastQuestion?.id};` - await executeSQL(deleteSql) - messages = JSON.parse(lastQuestion?.messages as any) || [] + await addMemory() + messages.push(JSON.parse(lastQuestion?.message as any)) } else { - // TODO 这里可以优化,如何携带上一次的对话内容 - messages = [ - { - role: 'system', - content: currentRole.description - }, - { - role: 'user', - content: value - } - ] + await addMemory() + messages.push({ + role: 'user', + content: value + }) } const { isThinking } = storeToRefs(useSessionStore()) @@ -41,15 +61,8 @@ export const getAiMessage = async (value?: string) => { isThinking.value = true - addSessionData(true, '', messages) - const result = await getOpenAIResultApi(messages) + addSessionData(true, '', messages.at(-1)!) + await getOpenAIResultStream(messages) isThinking.value = false - - console.log('result', result) - - if (!result) return - - // TODO 处理流式输出的结果 - addSessionData(false, '', result.message) }