Skip to content

Commit

Permalink
update chat stream
Browse files Browse the repository at this point in the history
  • Loading branch information
SchneeHertz committed Aug 17, 2023
1 parent 220b9f9 commit f0ae99f
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 48 deletions.
100 changes: 57 additions & 43 deletions index.js
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ const { STORE_PATH, LOG_PATH, AUDIO_PATH, SPEECH_AUDIO_PATH } = require('./utils
const { getStore, setStore } = require('./modules/store.js')
const { getSpeechText } = require('./modules/whisper.js')
const { ttsPromise } = require('./modules/edge-tts.js')
const { openaiChat, openaiChatStream, openaiEmbedding } = require('./modules/common.js')
const { openaiChatStream, openaiEmbedding } = require('./modules/common.js')
const { functionAction, functionInfo, functionList } = require('./modules/functions.js')
const {config: {
DEFAULT_MODEL,
Expand Down Expand Up @@ -193,19 +193,56 @@ const resloveAdminPrompt = async ({prompt, triggerRecord})=> {
{role: 'user', content: prompt}
]

let resContent = ''
history.push({role: 'user', content: prompt})
history = _.takeRight(history, 50)
setStore('history', history)

let resTextTemp = ''
let resText = ''
let clientMessageId = nanoid()
let speakIndex = STATUS.speakIndex
STATUS.speakIndex += 1
let resFunction
let resArgument = ''
await openaiChat({
model: DEFAULT_MODEL,
messages,
functions: functionInfo
})
.then(async res=>{
resContent = res.choices[0].message.content
resFunction = res.choices[0].message?.function_call?.name
resArgument = res.choices[0].message?.function_call?.arguments
if (resFunction && resArgument) {

try {
for await (const {token, f_token} of openaiChatStream({
model: DEFAULT_MODEL,
messages,
functions: functionInfo
})) {
if (token) {
resTextTemp += token
resText += token
messageSend({
id: clientMessageId,
from,
text: resText
})
if (triggerRecord) {
if (resTextTemp.includes('\n')) {
let splitResText = resTextTemp.split('\n')
splitResText = _.compact(splitResText)
if (splitResText.length > 1) {
resTextTemp = splitResText.pop()
} else {
resTextTemp = ''
}
let pickFirstParagraph = splitResText.join('\n')
let speakText = pickFirstParagraph.replace(/[^a-zA-Z0-9一-龟]+[喵嘻捏][^a-zA-Z0-9一-龟]*$/, '喵~')
speakTextList.push({
text: speakText,
speakIndex,
})
}
}
}
let {name, arguments: arg} = f_token
if (name) resFunction = name
if (arg) resArgument += arg
}

if (!resText && resFunction && resArgument) {
messageLogAndSend({
id: nanoid(),
from,
Expand All @@ -225,43 +262,17 @@ const resloveAdminPrompt = async ({prompt, triggerRecord})=> {
console.log(e)
functionCallResult = ''
}
let functionCalling = [res.choices[0].message, {role: "function", name: resFunction, content: functionCallResult}]
let functionCalling = [
{role: "assistant", content: null, function_call: {name: resFunction, arguments: resArgument}},
{role: "function", name: resFunction, content: functionCallResult}
]
messages.push(...functionCalling)
history.push(...functionCalling)
history = _.takeRight(history, 50)
setStore('history', history)
if (functionCallResult) console.log(functionCalling)
}
})
.catch(e=>console.log(e))

let resTextTemp = ''
let resText = ''
let clientMessageId = nanoid()
let speakIndex = STATUS.speakIndex
STATUS.speakIndex += 1

try {
if (resContent && !resFunction) {
resText = resContent
messageSend({
id: clientMessageId,
from,
text: resText
})
if (triggerRecord) {
let splitResText = resContent.split('\n')
splitResText = _.compact(splitResText)
for (let paragraph of splitResText ){
let speakText = paragraph.replace(/[^a-zA-Z0-9一-龟]+[喵嘻捏][^a-zA-Z0-9一-龟]*$/, '喵~')
speakTextList.push({
text: speakText,
speakIndex
})
}
}
} else {
for await (const token of openaiChatStream({
for await (const {token} of openaiChatStream({
model: DEFAULT_MODEL,
messages,
})) {
Expand Down Expand Up @@ -291,6 +302,8 @@ const resloveAdminPrompt = async ({prompt, triggerRecord})=> {
}
}
}


if (triggerRecord) {
if (resTextTemp) {
let speakText = resTextTemp.replace(/[^a-zA-Z0-9一-龟]+[喵嘻捏][^a-zA-Z0-9一-龟]*$/, '喵~')
Expand All @@ -300,6 +313,7 @@ const resloveAdminPrompt = async ({prompt, triggerRecord})=> {
})
}
}

messageLog({
id: clientMessageId,
from,
Expand Down
22 changes: 17 additions & 5 deletions modules/common.js
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,28 @@ const openaiChat = ({ model = DEFAULT_MODEL, messages, functions, function_call
* - messages {array}: An array of message objects representing the conversation.
* @return {generator} A generator that yields tokens from the chat stream.
*/
const openaiChatStream = async function* ({ model = DEFAULT_MODEL, messages }) {
const response = await openai.chat.completions.create({
const openaiChatStream = async function* ({ model = DEFAULT_MODEL, messages, functions, function_call }) {
  // Build the request once instead of duplicating the create() call.
  // The function-calling fields are added only when `functions` is provided,
  // so the API never receives `functions: undefined`.
  const request = {
    model,
    messages,
    presence_penalty: 0.2,
    frequency_penalty: 0.2,
    stream: true,
  }
  if (functions) {
    request.functions = functions
    if (function_call !== undefined) request.function_call = function_call
  }
  const response = await openai.chat.completions.create(request)
  for await (const part of response) {
    // finish_reason is reported on the choice, not inside `delta`, in the
    // chat-completions stream. The previous check read
    // `choices[0].delta.finish_reason`, which is always undefined, so the
    // loop never ended early and only stopped when the stream itself closed.
    const finishReason = _.get(part, 'choices[0].finish_reason')
    if (['stop', 'function_call'].includes(finishReason)) return
    const token = _.get(part, 'choices[0].delta.content')
    const f_token = _.get(part, 'choices[0].delta.function_call', {})
    // Yield whenever there is content OR a function-call fragment; callers
    // destructure `{token, f_token}` and handle whichever is present.
    if (token || !_.isEmpty(f_token)) yield {token, f_token}
  }
}

Expand Down Expand Up @@ -136,6 +147,7 @@ const azureOpenaiEmbedding = ({ input, model = 'text-embedding-ada-002', timeout


module.exports = {
openai,
openaiChat,
openaiChatStream,
openaiEmbedding,
Expand Down

0 comments on commit f0ae99f

Please sign in to comment.