Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
import path from 'path'
import { AIMessage } from '@langchain/core/messages'
import { Document } from '@langchain/core/documents'
import { RunnableLambda } from '@langchain/core/runnables'
import { ICommonObject, IMessage, INodeData, IServerSideEventStreamer } from '../../../src/Interface'

// Load the node under test via CJS require; the module exports its class as { nodeClass }.
const { nodeClass: ConversationalRetrievalQAChainNode } = require('./ConversationalRetrievalQAChain')
// Resolve FakeListChatModel directly out of @langchain/core's compiled dist —
// presumably because the testing chat models are not reachable through the
// package's public export map (TODO confirm against the installed version).
const langChainCorePackagePath = require.resolve('@langchain/core/package.json')
const { FakeListChatModel } = require(path.join(path.dirname(langChainCorePackagePath), 'dist/utils/testing/chat_models.cjs'))

// Configuration knobs for createHarness; all fields default to a streaming run
// with one canned answer and source documents enabled.
type HarnessOptions = {
    // Prior chat turns the mocked memory resolves from getChatMessages.
    history?: IMessage[]
    // Scripted model outputs, consumed in order (last one repeats when exhausted).
    responses?: string[]
    // Mirrors the node input of the same name.
    returnSourceDocuments?: boolean
    // Toggles both token-level streaming in the fake model and the SSE options flag.
    shouldStreamResponse?: boolean
}

/**
 * Builds a server-side-event streamer double whose four event hooks are jest
 * mocks, so tests can assert exactly which SSE events the node emitted.
 */
const createSseStreamer = (): jest.Mocked<IServerSideEventStreamer> => {
    const streamer = {
        streamStartEvent: jest.fn(),
        streamTokenEvent: jest.fn(),
        streamSourceDocumentsEvent: jest.fn(),
        streamEndEvent: jest.fn()
    }
    // Only the four hooks above are exercised; the cast papers over the rest of the interface.
    return streamer as unknown as jest.Mocked<IServerSideEventStreamer>
}

/**
 * Creates a chat-memory double seeded with the given history. Reads resolve to
 * `history`; writes and clears resolve to undefined, and all three are jest
 * mocks so call arguments can be asserted.
 */
const createMemory = (history: IMessage[] = []) => {
    const getChatMessages = jest.fn().mockResolvedValue(history)
    const addChatMessages = jest.fn().mockResolvedValue(undefined)
    const clearChatMessages = jest.fn().mockResolvedValue(undefined)
    return { getChatMessages, addChatMessages, clearChatMessages }
}

/**
 * Wraps FakeListChatModel so that, when a streaming run is requested, each
 * scripted response is emitted character-by-character through the callback
 * manager before the full generation is returned. Without streaming the fake
 * model is returned untouched.
 */
const createStreamingCapableModel = (responses: string[], shouldStreamResponse: boolean) => {
    const model = new FakeListChatModel({ responses })

    if (!shouldStreamResponse) {
        return model
    }

    let callCount = 0
    model._generate = async (
        _messages: unknown,
        _options: unknown,
        runManager: { handleLLMNewToken?: (token: string) => Promise<void> }
    ) => {
        // Replay responses in order; once exhausted, keep repeating the last one.
        const reply = responses[Math.min(callCount, responses.length - 1)] ?? ''
        callCount += 1

        // One callback per character so the SSE token stream can be observed by tests.
        for (const ch of reply) {
            await runManager?.handleLLMNewToken?.(ch)
        }

        return {
            generations: [
                {
                    text: reply,
                    message: new AIMessage(reply)
                }
            ],
            llmOutput: {}
        }
    }

    return model
}

/**
 * Assembles everything node.run needs for one test: a scripted chat model, a
 * single-document retriever, mocked memory and SSE streamer, plus the
 * INodeData/options payloads wired to fixed 'test-session' / 'chat-1' ids.
 * Returns all the doubles so individual tests can assert on them.
 */
const createHarness = ({
    history = [],
    responses = ['Answer from retrieved docs'],
    returnSourceDocuments = true,
    shouldStreamResponse = true
}: HarnessOptions = {}) => {
    const sseStreamer = createSseStreamer()
    const memory = createMemory(history)
    const model = createStreamingCapableModel(responses, shouldStreamResponse)
    // Always yields the same single document so assertions stay deterministic.
    const retriever = RunnableLambda.from(async () => [
        new Document({ pageContent: 'Relevant context from the retriever', metadata: { id: 'doc-1' } })
    ])
    const node = new ConversationalRetrievalQAChainNode({ sessionId: 'test-session' })

    const nodeData = {
        id: 'node-1',
        label: 'Conversational Retrieval QA Chain',
        name: 'conversationalRetrievalQAChain',
        type: 'ConversationalRetrievalQAChain',
        icon: 'qa.svg',
        version: 3,
        category: 'Chains',
        baseClasses: ['ConversationalRetrievalQAChain'],
        inputs: {
            model,
            vectorStoreRetriever: retriever,
            memory,
            returnSourceDocuments
        }
    } as unknown as INodeData

    // Silent logger: the node logs verbosely and tests don't care about output.
    const silentLogger = {
        verbose: jest.fn(),
        debug: jest.fn(),
        info: jest.fn(),
        warn: jest.fn(),
        error: jest.fn()
    }

    const options: ICommonObject = {
        shouldStreamResponse,
        sseStreamer,
        chatId: 'chat-1',
        logger: silentLogger
    }

    return { node, nodeData, options, memory, model, retriever, sseStreamer }
}

// Integration-style tests that drive the real node implementation against the
// mocked model/retriever/memory/SSE doubles built by createHarness.
describe('ConversationalRetrievalQAChain test harness', () => {
    it('runs the node with focused streaming callback test doubles', async () => {
        const harness = createHarness()

        const result = await harness.node.run(harness.nodeData, 'What does the document say?', harness.options)
        // Yield one macrotask so any queued streaming callbacks flush before asserting.
        await new Promise((resolve) => setTimeout(resolve, 0))

        expect(result).toEqual({
            text: 'Answer from retrieved docs',
            sourceDocuments: [
                expect.objectContaining({
                    pageContent: 'Relevant context from the retriever',
                    metadata: { id: 'doc-1' }
                })
            ]
        })
        // Memory is read with the harness session id; no prepend messages in this run.
        expect(harness.memory.getChatMessages).toHaveBeenCalledWith('test-session', false, undefined)
        // Both the user turn and the model's answer are persisted under the session id.
        expect(harness.memory.addChatMessages).toHaveBeenCalledWith(
            [
                {
                    text: 'What does the document say?',
                    type: 'userMessage'
                },
                {
                    text: 'Answer from retrieved docs',
                    type: 'apiMessage'
                }
            ],
            'test-session'
        )
        expect(harness.sseStreamer.streamEndEvent).toHaveBeenCalledWith('chat-1')
    })

    it('streams progressive answer tokens during a streaming run', async () => {
        const harness = createHarness({
            responses: ['Progressive streaming answer']
        })

        await harness.node.run(harness.nodeData, 'Stream the answer progressively', harness.options)
        await new Promise((resolve) => setTimeout(resolve, 0))

        // streamTokenEvent is called as (chatId, token); collect just the tokens.
        const streamedTokens = harness.sseStreamer.streamTokenEvent.mock.calls.map(([, token]) => token)

        expect(harness.sseStreamer.streamStartEvent).toHaveBeenCalled()
        // More than one call proves the answer arrived incrementally, not as one blob.
        expect(streamedTokens.length).toBeGreaterThan(1)
        expect(streamedTokens.join('')).toBe('Progressive streaming answer')
    })

    it('emits source documents and does not leak condensed-question text in streamed tokens', async () => {
        // With chat history present, the chain makes a first LLM call to condense
        // the question; its output must never be streamed to the client.
        const condensedQuestionText = 'CONDENSED QUESTION SHOULD NOT STREAM'
        const finalAnswerText = 'Final streamed answer from retrieved docs'
        const harness = createHarness({
            history: [
                { message: 'Earlier user question', type: 'userMessage' },
                { message: 'Earlier assistant answer', type: 'apiMessage' }
            ],
            // First scripted response feeds the condense step, second the answer step.
            responses: [condensedQuestionText, finalAnswerText],
            returnSourceDocuments: true,
            shouldStreamResponse: true
        })

        await harness.node.run(harness.nodeData, 'Follow-up question', harness.options)
        await new Promise((resolve) => setTimeout(resolve, 0))

        const streamedTokens = harness.sseStreamer.streamTokenEvent.mock.calls.map(([, token]) => token)
        const streamedText = streamedTokens.join('')

        expect(streamedText).toBe(finalAnswerText)
        expect(streamedText).not.toContain(condensedQuestionText)
        expect(harness.sseStreamer.streamSourceDocumentsEvent).toHaveBeenCalledWith('chat-1', [
            expect.objectContaining({
                pageContent: 'Relevant context from the retriever',
                metadata: { id: 'doc-1' }
            })
        ])
    })
})
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import { applyPatch } from 'fast-json-patch'
import { DataSource } from 'typeorm'
import { BaseLanguageModel } from '@langchain/core/language_models/base'
import { BaseRetriever } from '@langchain/core/retrievers'
Expand All @@ -12,7 +11,7 @@ import type { Document } from '@langchain/core/documents'
import { BufferMemoryInput } from '@langchain/classic/memory'
import { ConversationalRetrievalQAChain } from '@langchain/classic/chains'
import { getBaseClasses, mapChatMessageToBaseMessage, createTextOnlyOutputParser } from '../../../src/utils'
import { ConsoleCallbackHandler, additionalCallbacks } from '../../../src/handler'
import { ConsoleCallbackHandler, CustomChainHandler, additionalCallbacks } from '../../../src/handler'
import {
FlowiseMemory,
ICommonObject,
Expand All @@ -31,6 +30,11 @@ type RetrievalChainInput = {
question: string
}

type ConversationalRetrievalQAResult = {
text: string
sourceDocuments: Document[]
}

const sourceRunnableName = 'FindDocs'

class ConversationalRetrievalQAChain_Chains implements INode {
Expand Down Expand Up @@ -220,6 +224,10 @@ class ConversationalRetrievalQAChain_Chains implements INode {
const answerChain = createChain(model, vectorStoreRetriever, rephrasePrompt, customResponsePrompt)

const history = ((await memory.getChatMessages(this.sessionId, false, prependMessages)) as IMessage[]) ?? []
const hasChatHistory = history.length > 0
// CustomChainHandler decrements skipK in handleLLMStart before checking tokens,
// so skipping the first LLM call requires an initial value of 2.
const skipK = hasChatHistory ? 2 : 0

const loggerHandler = new ConsoleCallbackHandler(options.logger, options?.orgId)
const additionalCallback = await additionalCallbacks(nodeData, options)
Expand All @@ -230,60 +238,17 @@ class ConversationalRetrievalQAChain_Chains implements INode {
callbacks.push(new LCConsoleCallbackHandler())
}

const stream = answerChain.streamLog(
{ question: input, chat_history: history },
{ callbacks },
{
includeNames: [sourceRunnableName]
}
)

let streamedResponse: Record<string, any> = {}
let sourceDocuments: ICommonObject[] = []
let text = ''
let isStreamingStarted = false

for await (const chunk of stream) {
streamedResponse = applyPatch(streamedResponse, chunk.ops).newDocument

if (streamedResponse.final_output) {
text = streamedResponse.final_output?.output
if (Array.isArray(streamedResponse?.logs?.[sourceRunnableName]?.final_output?.output)) {
sourceDocuments = streamedResponse?.logs?.[sourceRunnableName]?.final_output?.output
if (shouldStreamResponse && returnSourceDocuments) {
if (sseStreamer) {
sseStreamer.streamSourceDocumentsEvent(chatId, sourceDocuments)
}
}
}
if (shouldStreamResponse && sseStreamer) {
sseStreamer.streamEndEvent(chatId)
}
}

if (
Array.isArray(streamedResponse?.streamed_output) &&
streamedResponse?.streamed_output.length &&
!streamedResponse.final_output
) {
const token = streamedResponse.streamed_output[streamedResponse.streamed_output.length - 1]

if (!isStreamingStarted) {
isStreamingStarted = true
if (shouldStreamResponse) {
if (sseStreamer) {
sseStreamer.streamStartEvent(chatId, token)
}
}
}
if (shouldStreamResponse) {
if (sseStreamer) {
sseStreamer.streamTokenEvent(chatId, token)
}
}
}
if (shouldStreamResponse) {
callbacks.push(new CustomChainHandler(sseStreamer, chatId, skipK, returnSourceDocuments))
}

const result = (await answerChain.invoke(
{ question: input, chat_history: history },
{ callbacks }
)) as ConversationalRetrievalQAResult
const text = result?.text ?? ''
const sourceDocuments = (result?.sourceDocuments ?? []) as ICommonObject[]

await memory.addChatMessages(
[
{
Expand Down Expand Up @@ -359,24 +324,37 @@ const createChain = (
) => {
const retrieverChain = createRetrieverChain(llm, retriever, rephrasePrompt)

const context = RunnableMap.from({
context: RunnableSequence.from([
({ question, chat_history }) => ({
question,
chat_history: formatChatHistoryAsString(chat_history)
const context = RunnableSequence.from([
RunnableMap.from({
sourceDocuments: RunnableSequence.from([
({ question, chat_history }) => ({
question,
chat_history: formatChatHistoryAsString(chat_history)
}),
retrieverChain
]).withConfig({ runName: sourceRunnableName }),
question: RunnableLambda.from((input: RetrievalChainInput) => input.question).withConfig({
runName: 'Itemgetter:question'
}),
retrieverChain,
RunnableLambda.from(formatDocs).withConfig({
runName: 'FormatDocumentChunks'
chat_history: RunnableLambda.from((input: RetrievalChainInput) => input.chat_history).withConfig({
runName: 'Itemgetter:chat_history'
})
]),
question: RunnableLambda.from((input: RetrievalChainInput) => input.question).withConfig({
runName: 'Itemgetter:question'
}),
chat_history: RunnableLambda.from((input: RetrievalChainInput) => input.chat_history).withConfig({
runName: 'Itemgetter:chat_history'
RunnableMap.from({
sourceDocuments: RunnableLambda.from((input: { sourceDocuments: Document[] }) => input.sourceDocuments).withConfig({
runName: 'Itemgetter:sourceDocuments'
}),
context: RunnableLambda.from((input: { sourceDocuments: Document[] }) => formatDocs(input.sourceDocuments)).withConfig({
runName: 'FormatDocumentChunks'
}),
question: RunnableLambda.from((input: { question: string }) => input.question).withConfig({
runName: 'Itemgetter:question'
}),
chat_history: RunnableLambda.from((input: { chat_history: BaseMessage[] }) => input.chat_history).withConfig({
runName: 'Itemgetter:chat_history'
})
})
}).withConfig({ tags: ['RetrieveDocs'] })
]).withConfig({ tags: ['RetrieveDocs'] })

const prompt = ChatPromptTemplate.fromMessages([
['system', responsePrompt],
Expand All @@ -398,7 +376,12 @@ const createChain = (
})
},
context,
responseSynthesizerChain
RunnableMap.from({
text: responseSynthesizerChain,
sourceDocuments: RunnableLambda.from((input: { sourceDocuments: Document[] }) => input.sourceDocuments).withConfig({
runName: 'Itemgetter:sourceDocuments'
})
})
])

return conversationalQAChain
Expand Down
Loading