feat(chat): added streaming (#808)
* feat(tmp): added streaming

* feat(streaming): implemented by changing order
StanGirard committed Jul 31, 2023
1 parent db40f3c commit 3166d08
Showing 8 changed files with 52 additions and 61 deletions.
19 changes: 4 additions & 15 deletions backend/core/llm/base.py
@@ -1,14 +1,12 @@
from abc import abstractmethod
from typing import AsyncIterable, List

-from langchain.callbacks import AsyncIteratorCallbackHandler
-from langchain.callbacks.base import AsyncCallbackHandler
+from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from langchain.chains import ConversationalRetrievalChain, LLMChain
from langchain.llms.base import LLM
from logger import get_logger
from models.settings import BrainSettings # Importing settings related to the 'brain'
from pydantic import BaseModel # For data validation and settings management
-from utils.constants import streaming_compatible_models

logger = get_logger(__name__)

@@ -33,7 +31,7 @@ class BaseBrainPicking(BaseModel):

openai_api_key: str = None # pyright: ignore reportPrivateUsage=none
callbacks: List[
-        AsyncCallbackHandler
+        AsyncIteratorCallbackHandler
] = None # pyright: ignore reportPrivateUsage=none

def _determine_api_key(self, openai_api_key, user_openai_api_key):
@@ -45,23 +43,14 @@ def _determine_api_key(self, openai_api_key, user_openai_api_key):

def _determine_streaming(self, model: str, streaming: bool) -> bool:
"""If the model name allows for streaming and streaming is declared, set streaming to True."""
-        if model in streaming_compatible_models and streaming:
-            return True
-        if model not in streaming_compatible_models and streaming:
-            logger.warning(
-                f"Streaming is not compatible with {model}. Streaming will be set to False."
-            )
-            return False
-        else:
-            return False
+        return streaming
def _determine_callback_array(
self, streaming
) -> List[AsyncIteratorCallbackHandler]: # pyright: ignore reportPrivateUsage=none
"""If streaming is set, set the AsyncIteratorCallbackHandler as the only callback."""
if streaming:
return [
-                AsyncIteratorCallbackHandler  # pyright: ignore reportPrivateUsage=none
+                AsyncIteratorCallbackHandler()  # pyright: ignore reportPrivateUsage=none
]

def __init__(self, **data):
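The key fix in base.py is constructing AsyncIteratorCallbackHandler() rather than passing the class object: the streaming LLM needs a handler instance whose internal queue it can push tokens into. A minimal sketch (not the repository's code) of how such an instance is typically wired and drained, assuming an OPENAI_API_KEY in the environment; the prompt and model defaults are illustrative only.

# Hedged sketch: wire one AsyncIteratorCallbackHandler instance into a
# streaming ChatOpenAI and read tokens off callback.aiter() while the
# generation runs in the background.
import asyncio

from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage


async def demo(prompt: str) -> str:
    callback = AsyncIteratorCallbackHandler()  # an instance, not the class
    llm = ChatOpenAI(streaming=True, callbacks=[callback])

    # Start the generation without awaiting it, so tokens can be consumed live.
    generation = asyncio.create_task(
        llm.agenerate([[HumanMessage(content=prompt)]])
    )

    tokens = []
    async for token in callback.aiter():  # ends when the handler's done event is set
        tokens.append(token)

    await generation
    return "".join(tokens)


if __name__ == "__main__":
    print(asyncio.run(demo("Say hello in one word.")))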
1 change: 1 addition & 0 deletions backend/core/llm/openai.py
@@ -58,5 +58,6 @@ def _create_llm(self, model, streaming=False, callbacks=None) -> BaseLLM:
temperature=self.temperature,
model=model,
streaming=streaming,
+            verbose=True,
callbacks=callbacks,
) # pyright: ignore reportPrivateUsage=none
50 changes: 29 additions & 21 deletions backend/core/llm/qa_base.py
@@ -1,20 +1,21 @@
import asyncio
import json
from abc import abstractmethod, abstractproperty
from typing import AsyncIterable, Awaitable

from langchain.chains import ConversationalRetrievalChain, LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms.base import BaseLLM
+from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from logger import get_logger
from models.chat import ChatHistory
from repository.chat.format_chat_history import format_chat_history
from repository.chat.get_chat_history import get_chat_history
from repository.chat.update_chat_history import update_chat_history
from repository.chat.update_message_by_id import update_message_by_id
from supabase.client import Client, create_client
from vectorstore.supabase import CustomSupabaseVectorStore
from langchain.chat_models import ChatOpenAI
from repository.chat.update_message_by_id import update_message_by_id
import json

from .base import BaseBrainPicking
from .prompts.CONDENSE_PROMPT import CONDENSE_QUESTION_PROMPT
@@ -60,31 +61,31 @@ def supabase_client(self) -> Client:

@property
def vector_store(self) -> CustomSupabaseVectorStore:

return CustomSupabaseVectorStore(
self.supabase_client,
self.embeddings,
table_name="vectors",
brain_id=self.brain_id,
)

@property
def question_llm(self):
return self._create_llm(model=self.model, streaming=False)

@property
def doc_llm(self):
return self._create_llm(
-            model=self.model, streaming=self.streaming, callbacks=self.callbacks
+            model=self.model, streaming=True, callbacks=self.callbacks
)

@property
def question_generator(self) -> LLMChain:
-        return LLMChain(llm=self.question_llm, prompt=CONDENSE_QUESTION_PROMPT)
+        return LLMChain(llm=self.question_llm, prompt=CONDENSE_QUESTION_PROMPT, verbose=True)

@property
def doc_chain(self) -> LLMChain:
return load_qa_chain(
llm=self.doc_llm, chain_type="stuff"
llm=self.doc_llm, chain_type="stuff", verbose=True
) # pyright: ignore reportPrivateUsage=none

@property
@@ -170,10 +171,20 @@ async def generate_stream(self, question: str) -> AsyncIterable:
:param question: The question
:return: An async iterable which generates the answer.
"""

history = get_chat_history(self.chat_id)
-        callback = self.callbacks[0]

+        callback = AsyncIteratorCallbackHandler()
+        self.callbacks = [callback]
+        model = ChatOpenAI(
+            streaming=True,
+            verbose=True,
+            callbacks=[callback],
+        )
+        llm = ChatOpenAI(temperature=0)
+        question_generator = LLMChain(llm=llm, prompt=CONDENSE_QUESTION_PROMPT)
+        doc_chain = load_qa_chain(model, chain_type="stuff")
+        qa = ConversationalRetrievalChain(
+            retriever=self.vector_store.as_retriever(), combine_docs_chain=doc_chain, question_generator=question_generator)
transformed_history = []

# Format the chat history into a list of tuples (human, ai)
@@ -183,23 +194,21 @@ async def wrap_done(fn: Awaitable, event: asyncio.Event):
response_tokens = []

# Wrap an awaitable with an event to signal when it's done or an exception is raised.

async def wrap_done(fn: Awaitable, event: asyncio.Event):
try:
await fn
except Exception as e:
logger.error(f"Caught exception: {e}")
finally:
event.set()

-        task = asyncio.create_task(
-            wrap_done(
-                self.qa._acall_chain(  # pyright: ignore reportPrivateUsage=none
-                    self.qa, question, transformed_history
-                ),
-                callback.done,  # pyright: ignore reportPrivateUsage=none
-            )
-        )

# Begin a task that runs in the background.

+        run = asyncio.create_task(wrap_done(
+            qa.acall({"question": question, "chat_history": transformed_history}),
+            callback.done,
+        ))

streamed_chat_history = update_chat_history(
chat_id=self.chat_id,
user_message=question,
Expand All @@ -216,8 +225,7 @@ async def wrap_done(fn: Awaitable, event: asyncio.Event):

yield f"data: {json.dumps(streamed_chat_history.to_dict())}"

-        await task
+        await run
# Join the tokens to create the assistant's response
assistant = "".join(response_tokens)

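generate_stream couples a background task with the callback's token iterator. Below is a hedged sketch of that wrap_done + aiter() pattern reduced to its essentials; chain construction, chat-history persistence, and SSE formatting are omitted, and qa stands for any chain whose document LLM was built with streaming=True and the same callback instance.

# Hedged sketch of the pattern used by generate_stream above.
import asyncio
from typing import Awaitable, List

from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler


async def wrap_done(fn: Awaitable, event: asyncio.Event) -> None:
    """Await the chain call, then signal completion even if it raised."""
    try:
        await fn
    except Exception as exc:
        print(f"Caught exception: {exc}")  # the real code logs via logger.error
    finally:
        event.set()


async def stream_answer(
    qa, callback: AsyncIteratorCallbackHandler, question: str, chat_history: list
) -> List[str]:
    # `callback` must be the same handler instance passed to the chain's
    # streaming LLM (callbacks=[callback]); otherwise aiter() never yields.
    run = asyncio.create_task(
        wrap_done(
            qa.acall({"question": question, "chat_history": chat_history}),
            callback.done,
        )
    )

    response_tokens = []
    async for token in callback.aiter():  # one item per streamed token
        response_tokens.append(token)     # the route yields these as SSE chunks

    await run
    return response_tokens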
25 changes: 8 additions & 17 deletions backend/core/routes/chat_routes.py
@@ -3,6 +3,7 @@
from http.client import HTTPException
from typing import List
from uuid import UUID
+from venv import logger

from auth import AuthBearer, get_current_user
from fastapi import APIRouter, Depends, Query, Request
@@ -18,9 +19,6 @@
from repository.chat.get_chat_history import get_chat_history
from repository.chat.get_user_chats import get_user_chats
from repository.chat.update_chat import ChatUpdatableProperties, update_chat
-from utils.constants import (
-    streaming_compatible_models,
-)

chat_router = APIRouter()

@@ -228,33 +226,26 @@ async def create_stream_question_handler(
current_user: User = Depends(get_current_user),
) -> StreamingResponse:
# TODO: check if the user has access to the brain
-    if not brain_id:
-        brain_id = get_default_user_brain_or_create_new(current_user).id

-    if chat_question.model not in streaming_compatible_models:
-        # Forward the request to the none streaming endpoint
-        return await create_question_handler(
-            request,
-            chat_question,
-            chat_id,
-            current_user,  # pyright: ignore reportPrivateUsage=none
-        )

try:
-        user_openai_api_key = request.headers.get("Openai-Api-Key")
-        streaming = True
+        logger.info(f"Streaming request for {chat_question.model}")
        check_user_limit(current_user)
+        if not brain_id:
+            brain_id = get_default_user_brain_or_create_new(current_user).id


gpt_answer_generator = OpenAIBrainPicking(
chat_id=str(chat_id),
model=chat_question.model,
max_tokens=chat_question.max_tokens,
temperature=chat_question.temperature,
brain_id=str(brain_id),
-            user_openai_api_key=user_openai_api_key,  # pyright: ignore reportPrivateUsage=none
-            streaming=streaming,
+            user_openai_api_key=current_user.user_openai_api_key,  # pyright: ignore reportPrivateUsage=none
+            streaming=True,
)

print("streaming")
return StreamingResponse(
gpt_answer_generator.generate_stream( # pyright: ignore reportPrivateUsage=none
chat_question.question
Expand Down
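On the FastAPI side, the route simply hands the async generator to StreamingResponse. A stripped-down sketch of that wiring, with an illustrative path, payload, and stand-in generator; the real handler also resolves the brain, checks usage limits, and builds OpenAIBrainPicking.

# Hedged sketch: expose an async generator of "data: ..." lines as an
# event-stream endpoint, the way the stream route returns
# gpt_answer_generator.generate_stream(question).
import asyncio
import json

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()


async def fake_stream(question: str):
    # Stand-in for OpenAIBrainPicking.generate_stream(question).
    for token in ["Hel", "lo", " world"]:
        await asyncio.sleep(0.1)
        yield f"data: {json.dumps({'assistant': token})}\n\n"


@app.post("/chat/{chat_id}/question/stream")
async def create_stream_question_handler(chat_id: str, question: str) -> StreamingResponse:
    return StreamingResponse(
        fake_stream(question),
        media_type="text/event-stream",
    )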
3 changes: 2 additions & 1 deletion backend/core/vectorstore/supabase.py
@@ -24,11 +24,12 @@ def __init__(
def similarity_search(
self,
query: str,
table: str = "match_vectors",
k: int = 6,
table: str = "match_vectors",
threshold: float = 0.5,
**kwargs: Any
) -> List[Document]:

vectors = self._embedding.embed_documents([query])
query_embedding = vectors[0]
res = self._client.rpc(
12 changes: 6 additions & 6 deletions frontend/app/chat/[chatId]/hooks/useChat.ts
@@ -9,8 +9,10 @@ import { useChatContext } from "@/lib/context/ChatProvider/hooks/useChatContext"
import { useToast } from "@/lib/hooks";
import { useEventTracking } from "@/services/analytics/useEventTracking";

-import { useQuestion } from "./useQuestion";
import { ChatQuestion } from "../types";
+import { useQuestion } from "./useQuestion";



// eslint-disable-next-line @typescript-eslint/explicit-module-boundary-types
export const useChat = () => {
@@ -68,11 +70,9 @@ export const useChat = () => {

void track("QUESTION_ASKED");

-      if (chatQuestion.model === "gpt-3.5-turbo") {
-        await addStreamQuestion(currentChatId, chatQuestion);
-      } else {
-        await addQuestionToModel(currentChatId, chatQuestion);
-      }
+      await addStreamQuestion(currentChatId, chatQuestion);


callback?.();
} catch (error) {
2 changes: 1 addition & 1 deletion frontend/app/chat/[chatId]/hooks/useQuestion.ts
@@ -79,7 +79,7 @@ export const useQuestion = (): UseChatService => {
Accept: "text/event-stream",
};
const body = JSON.stringify(chatQuestion);

console.log("Calling API...");
try {
const response = await fetchInstance.post(
`/chat/${chatId}/question/stream?brain_id=${currentBrain.id}`,
1 change: 1 addition & 0 deletions frontend/lib/context/BrainConfigProvider/types.ts
@@ -21,6 +21,7 @@ export type BrainConfigContextType = {
// export const openAiModels = ["gpt-3.5-turbo", "gpt-4"] as const; ## TODO activate GPT4 when not in demo mode

export const openAiModels = [
"gpt-3.5-turbo",
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k",
] as const;
