Skip to content

Commit

Permalink
feat: add GPT3Tokenizer #90 (#94)
Browse files Browse the repository at this point in the history
  • Loading branch information
orangelckc committed Mar 25, 2023
1 parent 16d28d5 commit 49b7ff1
Show file tree
Hide file tree
Showing 7 changed files with 214 additions and 119 deletions.
1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
"@types/marked": "^4.0.8",
"clipboard": "^2.0.11",
"dayjs": "^1.11.7",
"gpt3-tokenizer": "^1.1.5",
"highlight.js": "^11.7.0",
"html2canvas": "^1.4.1",
"markdown-it": "^13.0.1",
Expand Down
19 changes: 16 additions & 3 deletions pnpm-lock.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

113 changes: 1 addition & 112 deletions src/api/openAi.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import {
fetchEventSource,
type EventSourceMessage
} from '@microsoft/fetch-event-source'
import type { MessageData, SessionData } from '@/types'
import type { MessageData } from '@/types'

/**
* 获取 openai 对话消息
Expand Down Expand Up @@ -112,114 +112,3 @@ export const getOpenAICreditApi = async () => {
}
}
}

/**
* 获取 ai 回答
* @param value 消息内容
*/
export const getAiMessage = async (value?: string) => {
const apiKey = getOpenAIKey()
if (!apiKey) return

const { isThinking, sessionDataList } = storeToRefs(useSessionStore())
const { updateSessionData } = useSessionStore()

try {
const { currentRole } = useRoleStore()

if (!currentRole) return

const messages: MessageData[] = []

const { currentSession, sessionDataList } = useSessionStore()
const { isMemory } = useSettingsStore()

const lastQuestion = sessionDataList.filter((item) => item.is_ask).at(-1)

// 记忆模式,或者是第一次对话,都要生成角色描述
if (sessionDataList.length < 3 || isMemory) {
messages.push({
role: 'system',
content: currentRole.description
})
}

// 获取记忆(限制5条),往前推直到出现 is_momery 为 false 的
// TODO 应该进行限流,防止出现过多的记忆,导致token超出
const addMemory = async () => {
if (isMemory) {
// TODO: 优化 sql
const sql = `SELECT * FROM session_data WHERE session_id = '${currentSession?.id}' ORDER BY id DESC LIMIT 5;`
const memoryList = (await executeSQL(sql)) as SessionData[]

let count = 0
const arr = []
while (count < memoryList.length) {
if (!memoryList[count].is_memory) break
arr.push(JSON.parse(memoryList[count++].message as any))
}
messages.push(...arr.reverse())
}
}

// 再次生成上一次问题
if (!value) {
if (!lastQuestion) return

// 为了保证统一,这之后的内容全部删掉
const deleteSql = `DELETE FROM session_data WHERE session_id = '${lastQuestion?.session_id}' AND id >= ${lastQuestion?.id};`
await executeSQL(deleteSql)

await addMemory()
messages.push(lastQuestion?.message)
} else {
await addMemory()
messages.push({
role: 'user',
content: value
})
}

const { isThinking } = storeToRefs(useSessionStore())
const { addSessionData } = useSessionStore()

isThinking.value = true

await addSessionData({
isAsk: true,
data: messages.at(-1)!
})

await addSessionData({
isAsk: false,
data: {
role: 'assistant',
content: ''
}
})

await getOpenAIResultStreamApi(messages)

isThinking.value = false
} catch ({ message }: any) {
sessionDataList.value.at(-1)!.message.content = message as any

updateSessionData(sessionDataList.value.at(-1)!)

isThinking.value = false
}
}

/**
* 获取apiKey
*/
const getOpenAIKey = () => {
const { apiKey } = useSettingsStore()

if (!apiKey && !import.meta.env.VITE_OPEN_AI_API_KEY) {
Message.warning('请先填写 OpenAi API Key')
return false
}

return apiKey || import.meta.env.VITE_OPEN_AI_API_KEY
}
33 changes: 31 additions & 2 deletions src/components/Function/index.vue
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,40 @@ import {
IconImage
} from '@arco-design/web-vue/es/icon'
import { emit } from '@tauri-apps/api/event'
import { estimateTokens, getMemoryList } from '@/utils'
const { currentRole } = storeToRefs(useRoleStore())
const { currentRole, textAreaValue } = storeToRefs(useRoleStore())
const sessionStore = useSessionStore()
const { switchSession, deleteSession, updateSessionData } = sessionStore
const { isThinking, sessionDataList, chatController } =
storeToRefs(sessionStore)
const { isMemory } = storeToRefs(useSettingsStore())
const disabled = computed(
() => isThinking.value || !sessionDataList.value.length
)
const tokenUsage = ref(0)
watch([textAreaValue, isMemory], async () => {
// 角色描述字符数
const roleTokens = estimateTokens(currentRole.value!.description)
// 输入字符数
const textAreaTokens = estimateTokens(textAreaValue.value)
// 记忆模式下额外消耗的字符数
const memroyList = await getMemoryList()
const memoryTokens = estimateTokens(
memroyList.map((item) => item.content).join('')
)
tokenUsage.value = textAreaTokens + roleTokens + memoryTokens
})
// 控制设置弹框
const modalVisible = ref(false)
const closeModal = () => {
Expand Down Expand Up @@ -91,8 +113,15 @@ const triggerScroll = () => {
<!-- TODO:把聊天对象移过来 -->
<template>
<div class="function text-6 relative flex justify-end">
<!-- 预估将要消耗的token -->
<div
class="left-1/5 text-4 -translate-1/2 absolute top-1/2"
v-if="textAreaValue.length"
>
{{ isMemory ? '记忆模式:' : '' }}预计消耗 {{ tokenUsage }} TK
</div>
<!-- 当前聊天角色对象 -->
<div class="top-50% left-50% text-4 -translate-1/2 absolute select-none">
<div class="text-4 -translate-1/2 absolute top-1/2 left-1/2">
正在与
<a-tooltip content="点我回到底部">
<span class="mark cursor-pointer" @click="triggerScroll">
Expand Down
42 changes: 40 additions & 2 deletions src/components/Session/index.vue
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import { listen } from '@tauri-apps/api/event'
import MarkdownIt from 'markdown-it'
import MarkdownItHighlight from 'markdown-it-highlightjs'
import type { SessionData } from '@/types'
import { estimateTokens } from '@/utils'
import dayjs from 'dayjs'
import utc from 'dayjs/plugin/utc'
Expand Down Expand Up @@ -59,6 +61,37 @@ const handleScroll = () => {
}
}
/**
* 计算单条消息消耗的token
* @param item 单条消息
*/
// FIXME 这里的item/data类型报错搞不定,只能用any
const calcToken = (item: any) => {
// 角色描述字符数
const roleToken = estimateTokens(currentRole.value!.description)
// 消息内容字符数
const contentToken = estimateTokens(item.message.content)
// 记忆模式下额外消耗的字符数
let memoryToken = 0
if (item.is_memory) {
// 获取sessionDataList中的此条之前的最后5条消息
const memoryList = sessionDataList.value
.filter((data: any) => data.id < item.id)
.slice(-5)
memoryToken = estimateTokens(
memoryList.map((data) => data.message.content).join('')
)
}
if (item.is_ask) {
return `${roleToken + contentToken + memoryToken}TK${
item.is_memory ? '*' : ''
}`
}
return `${contentToken}TK`
}
onMounted(() => {
listen('scroll-to-bottom', () => {
isAutoScroll.value = true
Expand Down Expand Up @@ -91,8 +124,13 @@ watch([currentSession, sessionDataList], () => {
:class="item.is_ask && 'flex-row-reverse'"
:key="item.id"
>
<Avatar class="w-12!" :value="item.is_ask ? uuid : currentRole?.name" />
<div class="flex flex-col items-center gap-1">
<Avatar
class="w-12!"
:value="item.is_ask ? uuid : currentRole?.name"
/>
<span class="text-gray text-xs">{{ calcToken(item) }}</span>
</div>
<div
class="relative flex w-[calc(100%-8rem)] flex-col gap-2"
:class="item.is_ask && 'items-end'"
Expand Down
1 change: 1 addition & 0 deletions src/utils/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ export * from './keyMap'
export * from './saveImage'
export * from './copy'
export * from './saveMarkdown'
export * from './openai'
Loading

0 comments on commit 49b7ff1

Please sign in to comment.