Skip to content

Commit

Permalink
fix(core): when using the plugin chat mode, chat memory isn't read af…
Browse files Browse the repository at this point in the history
…ter restarting
  • Loading branch information
dingyi222666 committed May 31, 2024
1 parent 160282b commit e27d31c
Show file tree
Hide file tree
Showing 17 changed files with 182 additions and 167 deletions.
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
"cross-env": "^7.0.3",
"esbuild": "^0.20.2",
"esbuild-register": "npm:@shigma/esbuild-register@^1.1.1",
"eslint": "^9.3.0",
"eslint": "^8.57.0",
"eslint-config-prettier": "^9.1.0",
"eslint-config-standard": "^17.1.0",
"eslint-plugin-import": "^2.29.1",
Expand Down
36 changes: 11 additions & 25 deletions packages/core/src/llm-core/agent/openai/index.ts
Original file line number Diff line number Diff line change
@@ -1,30 +1,12 @@
import {
AIMessage,
BaseMessage,
FunctionMessage,
ToolMessage
} from '@langchain/core/messages'
import { AIMessage, BaseMessage, FunctionMessage, ToolMessage } from '@langchain/core/messages'
import { BaseOutputParser } from '@langchain/core/output_parsers'
import { SystemPrompts } from '../../chain/base'
import {
FunctionsAgentAction,
OpenAIFunctionsAgentOutputParser,
OpenAIToolsAgentOutputParser,
ToolsAgentAction
} from './output_parser'
import { AgentAction, AgentFinish, AgentStep } from 'langchain/agents'
import {
ChatPromptTemplate,
HumanMessagePromptTemplate,
MessagesPlaceholder
} from '@langchain/core/prompts'
import { ChatPromptTemplate, MessagesPlaceholder } from '@langchain/core/prompts'
import { RunnableLambda, RunnablePassthrough, RunnableSequence } from '@langchain/core/runnables'
import { StructuredTool } from '@langchain/core/tools'
import { AgentAction, AgentFinish, AgentStep } from 'langchain/agents'
import { SystemPrompts } from '../../chain/base'
import { ChatLunaChatModel } from '../../platform/model'
import {
RunnableLambda,
RunnablePassthrough,
RunnableSequence
} from '@langchain/core/runnables'
import { FunctionsAgentAction, OpenAIFunctionsAgentOutputParser, OpenAIToolsAgentOutputParser, ToolsAgentAction } from './output_parser'

/**
* Checks if the given action is a FunctionsAgentAction.
Expand Down Expand Up @@ -102,7 +84,8 @@ export function createOpenAIAgent({
const prompt = ChatPromptTemplate.fromMessages([
new MessagesPlaceholder('preset'),
new MessagesPlaceholder('chat_history'),
HumanMessagePromptTemplate.fromTemplate('{input}'),
new MessagesPlaceholder('input'),
// HumanMessagePromptTemplate.fromTemplate('{input_text}'),
new MessagesPlaceholder('agent_scratchpad')
])

Expand All @@ -120,6 +103,9 @@ export function createOpenAIAgent({
agent_scratchpad: (input: { steps: AgentStep[] }) =>
_formatIntermediateSteps(input.steps),
preset: () => preset
/* // @ts-expect-error eslint-disable-next-line @typescript-eslint/naming-convention
input_text: (input: { input: BaseMessage[] }) =>
getMessageContent(input.input[0].content) */
}),
prompt,
llmWithTools,
Expand Down
20 changes: 17 additions & 3 deletions packages/core/src/llm-core/chain/browsing_chain.ts
Original file line number Diff line number Diff line change
@@ -1,15 +1,29 @@
/* eslint-disable max-len */
import { Embeddings } from '@langchain/core/embeddings'
import { AIMessage, BaseMessage, SystemMessage } from '@langchain/core/messages'
import { HumanMessagePromptTemplate, MessagesPlaceholder, PromptTemplate } from '@langchain/core/prompts'
import {
HumanMessagePromptTemplate,
MessagesPlaceholder,
PromptTemplate
} from '@langchain/core/prompts'
import { StructuredTool, Tool } from '@langchain/core/tools'
import { ChainValues } from '@langchain/core/utils/types'
import { BufferMemory, ConversationSummaryMemory, VectorStoreRetrieverMemory } from 'langchain/memory'
import {
BufferMemory,
ConversationSummaryMemory,
VectorStoreRetrieverMemory
} from 'langchain/memory'
import { MemoryVectorStore } from 'langchain/vectorstores/memory'
import { logger } from '../..'
import { ChatLunaSaveableVectorStore } from '../model/base'
import { ChatLunaChatModel } from '../platform/model'
import { callChatHubChain, ChatHubLLMCallArg, ChatHubLLMChain, ChatHubLLMChainWrapper, SystemPrompts } from './base'
import {
callChatHubChain,
ChatHubLLMCallArg,
ChatHubLLMChain,
ChatHubLLMChainWrapper,
SystemPrompts
} from './base'
import { ChatHubBrowsingPrompt } from './prompt'

// github.com/langchain-ai/weblangchain/blob/main/nextjs/app/api/chat/stream_log/route.ts#L81
Expand Down
29 changes: 6 additions & 23 deletions packages/core/src/llm-core/chain/plugin_chat_chain.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,7 @@ export class ChatLunaPluginChain

tools: ChatHubTool[]

baseMessages: BaseMessage[] = []

currentBufferMemory: BufferMemory
baseMessages: BaseMessage[] = undefined

constructor({
historyMemory,
Expand Down Expand Up @@ -79,19 +77,6 @@ export class ChatLunaPluginChain
tools: StructuredTool[],
systemPrompts: SystemPrompts
) {
if (this.currentBufferMemory == null) {
this.currentBufferMemory = new BufferMemory({
returnMessages: true,
memoryKey: 'chat_history',
inputKey: 'input',
outputKey: 'output'
})

for (const message of this.baseMessages) {
await this.currentBufferMemory.chatHistory.addMessage(message)
}
}

return AgentExecutor.fromAgentAndTools({
tags: ['openai-functions'],
agent: createOpenAIAgent({
Expand All @@ -100,7 +85,7 @@ export class ChatLunaPluginChain
preset: systemPrompts
}),
tools,
memory: this.currentBufferMemory,
memory: undefined,
verbose: true
})
}
Expand Down Expand Up @@ -162,23 +147,20 @@ export class ChatLunaPluginChain
chat_history?: BaseMessage[]
id?: string
} = {
input: message.content
input: [message]
}

this.baseMessages =
this.baseMessages ??
(await this.historyMemory.chatHistory.getMessages())

// requests['chat_history'] = this.baseMessages
requests['chat_history'] = this.baseMessages

requests['id'] = conversationId

const [activeTools, recreate] = this._getActiveTools(
session,
(this.currentBufferMemory != null
? await this.currentBufferMemory.chatHistory.getMessages()
: this.baseMessages
).concat(message)
this.baseMessages.concat(message)
)

if (recreate || this.executor == null) {
Expand Down Expand Up @@ -233,6 +215,7 @@ export class ChatLunaPluginChain

await this.historyMemory.chatHistory.addMessage(message)
await this.historyMemory.chatHistory.addAIChatMessage(responseString)
this.baseMessages.push(message, new AIMessage(responseString))

return response
}
Expand Down
7 changes: 4 additions & 3 deletions packages/core/src/services/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@ declare module 'koishi' {
chatluna: ChatLunaService
}

interface Events {
'chatluna/before-check-sender'(session: Session): Promise<boolean>
}

interface Tables {
chathub_room: ConversationRoom
chathub_room_member: ConversationRoomMemberInfo
chathub_room_group_member: ConversationRoomGroupInfo
chathub_user: ConversationRoomUserInfo
}
interface Events {
'chatluna/before-check-sender'(session: Session): Promise<boolean>
}
}
22 changes: 20 additions & 2 deletions packages/core/src/utils/string.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// 导出一个模糊查询的函数,参数是一个字符串和一个字符串数组,返回值是一个布尔值
import { BaseMessage } from '@langchain/core/messages'

export function fuzzyQuery(source: string, keywords: string[]): boolean {
// 遍历每一个关键词
for (const keyword of keywords) {
const match = source.includes(keyword)
// 如果距离小于等于最大距离,说明匹配成功,返回 true
Expand All @@ -11,3 +11,21 @@ export function fuzzyQuery(source: string, keywords: string[]): boolean {
// 如果遍历完所有关键词都没有匹配成功,返回 false
return false
}

/**
 * Flattens a LangChain message `content` value into a plain string.
 *
 * A `BaseMessage['content']` is either a bare string or an array of
 * content parts; only the parts tagged `type: 'text'` contribute text,
 * all other part kinds (e.g. image parts) are ignored.
 *
 * @param message - the `content` field of a `BaseMessage`
 * @returns the concatenated text, or `''` when content is null/undefined
 */
export function getMessageContent(message: BaseMessage['content']) {
    // Nullish content carries no text at all.
    if (message == null) {
        return ''
    }

    // Plain-string content is already in its final form.
    if (typeof message === 'string') {
        return message
    }

    // Multi-part content: keep only the text parts, in order, and join them.
    return message
        .filter((part) => part.type === 'text')
        .map((part) => part.text)
        .join('')
}
57 changes: 29 additions & 28 deletions packages/gemini-adapter/src/requester.ts
Original file line number Diff line number Diff line change
Expand Up @@ -127,12 +127,12 @@ export class GeminiRequester
jsonParser.write(rawData)
return true
},
10
0
)

let content = ''

let isVisionModel = params.model.includes('vision')
let isOldVisionModel = params.model.includes('vision')

const functionCall: ChatCompletionMessageFunctionCall & {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
Expand All @@ -149,45 +149,46 @@ export class GeminiRequester
partAsType<ChatFunctionCallingPart>(chunk)

if (messagePart.text) {
content += messagePart.text
if (params.tools != null) {
content = messagePart.text
} else {
content += messagePart.text
}

// match /w*model:
if (isVisionModel && /\s*model:\s*/.test(content)) {
isVisionModel = false
content = content.replace(/\s*model:\s*/, '')
if (isOldVisionModel && /\s*model:\s*/.test(content)) {
isOldVisionModel = false
content = messagePart.text.replace(/\s*model:\s*/, '')
}
}

if (chatFunctionCallingPart.functionCall) {
const deltaFunctionCall =
chatFunctionCallingPart.functionCall
const deltaFunctionCall = chatFunctionCallingPart.functionCall

if (deltaFunctionCall) {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
let args: any =
deltaFunctionCall.args?.input ??
deltaFunctionCall.args
if (deltaFunctionCall) {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
let args: any =
deltaFunctionCall.args?.input ?? deltaFunctionCall.args

try {
let parsedArgs = JSON.parse(args)
try {
let parsedArgs = JSON.parse(args)

if (typeof parsedArgs !== 'string') {
args = parsedArgs
}
if (typeof parsedArgs !== 'string') {
args = parsedArgs
}

parsedArgs = JSON.parse(args)
parsedArgs = JSON.parse(args)

if (typeof parsedArgs !== 'string') {
args = parsedArgs
}
} catch (e) {}
if (typeof parsedArgs !== 'string') {
args = parsedArgs
}
// eslint-disable-next-line @typescript-eslint/no-unused-vars
} catch (e) {}

functionCall.args = JSON.stringify(args)
functionCall.args = JSON.stringify(args)

functionCall.name = deltaFunctionCall.name
functionCall.name = deltaFunctionCall.name

functionCall.arguments = deltaFunctionCall.args
}
functionCall.arguments = deltaFunctionCall.args
}

try {
Expand Down
18 changes: 7 additions & 11 deletions packages/gemini-adapter/src/utils.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
import {
AIMessageChunk,
BaseMessage,
ChatMessageChunk,
HumanMessageChunk,
MessageType,
SystemMessageChunk
} from '@langchain/core/messages'
import { AIMessageChunk, BaseMessage, ChatMessageChunk, HumanMessageChunk, MessageType, SystemMessageChunk } from '@langchain/core/messages'
import { StructuredTool } from '@langchain/core/tools'
import { zodToJsonSchema } from 'zod-to-json-schema'
import {
ChatCompletionFunction,
ChatCompletionResponseMessage,
Expand All @@ -15,8 +10,6 @@ import {
ChatPart,
ChatUploadDataPart
} from './types'
import { StructuredTool } from '@langchain/core/tools'
import { zodToJsonSchema } from 'zod-to-json-schema'

export async function langchainMessageToGeminiMessage(
messages: BaseMessage[],
Expand Down Expand Up @@ -120,7 +113,10 @@ export async function langchainMessageToGeminiMessage(
]
}

if (model.includes('vision') && images != null) {
if (
(model.includes('vision') || model.includes('gemini-1.5')) &&
images != null
) {
for (const image of images) {
result.parts.push({
inline_data: {
Expand Down
Loading

0 comments on commit e27d31c

Please sign in to comment.