feat(Unplug): chatting without brain streaming (#970)
* feat(Unplug): adds a new basic headless LLM

* feat(Unplug): adds a chatting-without-brain option when not streaming

* feat(Unplug): adds a chatting-without-brain option when streaming
StepanLebedevTheodo committed Aug 18, 2023
1 parent 7281fd9 commit 600ff1e
Showing 6 changed files with 300 additions and 75 deletions.
2 changes: 2 additions & 0 deletions backend/core/llm/__init__.py
@@ -1,9 +1,11 @@
from .base import BaseBrainPicking
from .qa_base import QABaseBrainPicking
from .openai import OpenAIBrainPicking
from .qa_headless import HeadlessQA

__all__ = [
    "BaseBrainPicking",
    "QABaseBrainPicking",
    "OpenAIBrainPicking",
    "HeadlessQA"
]
207 changes: 207 additions & 0 deletions backend/core/llm/qa_headless.py
@@ -0,0 +1,207 @@
import asyncio
import json
from uuid import UUID

from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from langchain.chat_models import ChatOpenAI
from langchain.chains import LLMChain
from langchain.llms.base import BaseLLM
from langchain.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
)
from repository.chat.update_message_by_id import update_message_by_id
from models.databases.supabase.chats import CreateChatHistory
from repository.chat.format_chat_history import format_chat_history
from repository.chat.get_chat_history import get_chat_history
from repository.chat.update_chat_history import update_chat_history
from repository.chat.format_chat_history import format_history_to_openai_mesages
from logger import get_logger
from models.chats import ChatQuestion
from repository.chat.get_chat_history import GetChatHistoryOutput


from pydantic import BaseModel

from typing import AsyncIterable, Awaitable, List

logger = get_logger(__name__)
SYSTEM_MESSAGE = "Your name is Quivr. You're a helpful assistant. If you don't know the answer, just say that you don't know, don't try to make up an answer."


class HeadlessQA(BaseModel):
    model: str = None  # type: ignore
    temperature: float = 0.0
    max_tokens: int = 256
    user_openai_api_key: str = None  # type: ignore
    openai_api_key: str = None  # type: ignore
    streaming: bool = False
    chat_id: str = None  # type: ignore
    callbacks: List[AsyncIteratorCallbackHandler] = None  # type: ignore

    def _determine_api_key(self, openai_api_key, user_openai_api_key):
        """If the user provided an API key, use it."""
        if user_openai_api_key is not None:
            return user_openai_api_key
        else:
            return openai_api_key

    def _determine_streaming(self, model: str, streaming: bool) -> bool:
        """If the model name allows for streaming and streaming is declared, set streaming to True."""
        return streaming

    def _determine_callback_array(
        self, streaming
    ) -> List[AsyncIteratorCallbackHandler]:  # pyright: ignore reportPrivateUsage=none
        """If streaming is set, set the AsyncIteratorCallbackHandler as the only callback."""
        if streaming:
            return [
                AsyncIteratorCallbackHandler()  # pyright: ignore reportPrivateUsage=none
            ]

    def __init__(self, **data):
        super().__init__(**data)

        self.openai_api_key = self._determine_api_key(
            self.openai_api_key, self.user_openai_api_key
        )
        self.streaming = self._determine_streaming(
            self.model, self.streaming
        )  # pyright: ignore reportPrivateUsage=none
        self.callbacks = self._determine_callback_array(
            self.streaming
        )  # pyright: ignore reportPrivateUsage=none

    def _create_llm(
        self, model, temperature=0, streaming=False, callbacks=None
    ) -> BaseLLM:
        """
        Determine the language model to be used.
        :param model: Language model name to be used.
        :param streaming: Whether to enable streaming of the model
        :param callbacks: Callbacks to be used for streaming
        :return: Language model instance
        """
        return ChatOpenAI(
            temperature=temperature,
            model=model,
            streaming=streaming,
            verbose=True,
            callbacks=callbacks,
            openai_api_key=self.openai_api_key,
        )  # pyright: ignore reportPrivateUsage=none

    def _create_prompt_template(self):
        messages = [
            HumanMessagePromptTemplate.from_template("{question}"),
        ]
        CHAT_PROMPT = ChatPromptTemplate.from_messages(messages)
        return CHAT_PROMPT

    def generate_answer(
        self, chat_id: UUID, question: ChatQuestion
    ) -> GetChatHistoryOutput:
        transformed_history = format_chat_history(get_chat_history(self.chat_id))
        messages = format_history_to_openai_mesages(transformed_history, SYSTEM_MESSAGE, question.question)
        answering_llm = self._create_llm(
            model=self.model, streaming=False, callbacks=self.callbacks
        )
        model_prediction = answering_llm.predict_messages(messages)  # pyright: ignore reportPrivateUsage=none
        answer = model_prediction.content

        new_chat = update_chat_history(
            CreateChatHistory(
                **{
                    "chat_id": chat_id,
                    "user_message": question.question,
                    "assistant": answer,
                    "brain_id": None,
                    "prompt_id": None,
                }
            )
        )

        return GetChatHistoryOutput(
            **{
                "chat_id": chat_id,
                "user_message": question.question,
                "assistant": answer,
                "message_time": new_chat.message_time,
                "prompt_title": None,
                "brain_name": None,
                "message_id": new_chat.message_id,
            }
        )

    async def generate_stream(
        self, chat_id: UUID, question: ChatQuestion
    ) -> AsyncIterable:
        callback = AsyncIteratorCallbackHandler()
        self.callbacks = [callback]

        transformed_history = format_chat_history(get_chat_history(self.chat_id))
        messages = format_history_to_openai_mesages(transformed_history, SYSTEM_MESSAGE, question.question)
        answering_llm = self._create_llm(
            model=self.model, streaming=True, callbacks=self.callbacks
        )

        CHAT_PROMPT = ChatPromptTemplate.from_messages(messages)
        headlessChain = LLMChain(llm=answering_llm, prompt=CHAT_PROMPT)

        response_tokens = []

        async def wrap_done(fn: Awaitable, event: asyncio.Event):
            try:
                await fn
            except Exception as e:
                logger.error(f"Caught exception: {e}")
            finally:
                event.set()

        run = asyncio.create_task(
            wrap_done(
                headlessChain.acall({}),
                callback.done,
            ),
        )

        streamed_chat_history = update_chat_history(
            CreateChatHistory(
                **{
                    "chat_id": chat_id,
                    "user_message": question.question,
                    "assistant": "",
                    "brain_id": None,
                    "prompt_id": None,
                }
            )
        )

        streamed_chat_history = GetChatHistoryOutput(
            **{
                "chat_id": str(chat_id),
                "message_id": streamed_chat_history.message_id,
                "message_time": streamed_chat_history.message_time,
                "user_message": question.question,
                "assistant": "",
                "prompt_title": None,
                "brain_name": None,
            }
        )

        async for token in callback.aiter():
            logger.info("Token: %s", token)  # type: ignore
            response_tokens.append(token)  # type: ignore
            streamed_chat_history.assistant = token  # type: ignore
            yield f"data: {json.dumps(streamed_chat_history.dict())}"

        await run
        assistant = "".join(response_tokens)

        update_message_by_id(
            message_id=str(streamed_chat_history.message_id),
            user_message=question.question,
            assistant=assistant,
        )

    class Config:
        arbitrary_types_allowed = True
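Taken together, the class can be exercised outside the FastAPI layer. Below is a minimal non-streaming sketch; it assumes the backend's Supabase-backed chat repository is reachable, the chat_id and API key are placeholders, and the ChatQuestion fields besides question are inferred from how the routes construct it:

from uuid import UUID

from llm.qa_headless import HeadlessQA
from models.chats import ChatQuestion

chat_id = UUID("00000000-0000-0000-0000-000000000000")  # placeholder: id of an existing chat
qa = HeadlessQA(
    model="gpt-3.5-turbo",
    temperature=0.0,
    max_tokens=256,
    user_openai_api_key="sk-...",  # placeholder OpenAI key
    chat_id=str(chat_id),
)
question = ChatQuestion(
    question="What is Quivr?",  # the only field generate_answer reads
    model="gpt-3.5-turbo",
    temperature=0.0,
    max_tokens=256,
)
answer = qa.generate_answer(chat_id, question)
print(answer.assistant)

The streaming path is symmetric: async-iterating qa.generate_stream(chat_id, question) yields "data: {...}" strings per token, and the full answer is persisted afterwards via update_message_by_id.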
17 changes: 16 additions & 1 deletion backend/core/repository/chat/format_chat_history.py
@@ -1,4 +1,19 @@
-def format_chat_history(history) -> list[tuple[str, str]]:
+from typing import List, Tuple
+from langchain.schema import AIMessage, HumanMessage, SystemMessage
+
+
+def format_chat_history(history) -> List[Tuple[str, str]]:
     """Format the chat history into a list of tuples (human, ai)"""
 
     return [(chat.user_message, chat.assistant) for chat in history]
+
+
+def format_history_to_openai_mesages(tuple_history: List[Tuple[str, str]], system_message: str, question: str) -> List[SystemMessage | HumanMessage | AIMessage]:
+    """Format the chat history into a list of Base Messages"""
+    messages = []
+    messages.append(SystemMessage(content=system_message))
+    for human, ai in tuple_history:
+        messages.append(HumanMessage(content=human))
+        messages.append(AIMessage(content=ai))
+    messages.append(HumanMessage(content=question))
+    return messages
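To make the output shape concrete, here is a toy run of the new helper with made-up history values:

history = [("Hi", "Hello! How can I help?")]
messages = format_history_to_openai_mesages(history, "You are Quivr.", "What can you do?")
# messages == [SystemMessage(content="You are Quivr."),
#              HumanMessage(content="Hi"),
#              AIMessage(content="Hello! How can I help?"),
#              HumanMessage(content="What can you do?")]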
65 changes: 39 additions & 26 deletions backend/core/routes/chat_routes.py
@@ -7,6 +7,7 @@
 from auth import AuthBearer, get_current_user
 from fastapi import APIRouter, Depends, HTTPException, Query, Request
 from fastapi.responses import StreamingResponse
+from llm.qa_headless import HeadlessQA
 from llm.openai import OpenAIBrainPicking
 from models.brains import Brain
 from models.brain_entity import BrainEntity
@@ -16,9 +17,6 @@
 from models.settings import LLMSettings, get_supabase_db
 from models.users import User
 from repository.brain.get_brain_details import get_brain_details
-from repository.brain.get_default_user_brain_or_create_new import (
-    get_default_user_brain_or_create_new,
-)
 from repository.chat.create_chat import CreateChatProperties, create_chat
 from repository.chat.get_chat_by_id import get_chat_by_id
 from repository.chat.get_chat_history import GetChatHistoryOutput, get_chat_history
@@ -190,17 +188,24 @@ async def create_question_handler(
     check_user_limit(current_user)
     LLMSettings()
 
-    if not brain_id:
-        brain_id = get_default_user_brain_or_create_new(current_user).brain_id
-
-    gpt_answer_generator = OpenAIBrainPicking(
-        chat_id=str(chat_id),
-        model=chat_question.model,
-        max_tokens=chat_question.max_tokens,
-        temperature=chat_question.temperature,
-        brain_id=str(brain_id),
-        user_openai_api_key=current_user.user_openai_api_key,  # pyright: ignore reportPrivateUsage=none
-    )
+    gpt_answer_generator: HeadlessQA | OpenAIBrainPicking
+    if brain_id:
+        gpt_answer_generator = OpenAIBrainPicking(
+            chat_id=str(chat_id),
+            model=chat_question.model,
+            max_tokens=chat_question.max_tokens,
+            temperature=chat_question.temperature,
+            brain_id=str(brain_id),
+            user_openai_api_key=current_user.user_openai_api_key,  # pyright: ignore reportPrivateUsage=none
+        )
+    else:
+        gpt_answer_generator = HeadlessQA(
+            model=chat_question.model,
+            temperature=chat_question.temperature,
+            max_tokens=chat_question.max_tokens,
+            user_openai_api_key=current_user.user_openai_api_key,  # pyright: ignore reportPrivateUsage=none
+            chat_id=str(chat_id),
+        )
 
     chat_answer = gpt_answer_generator.generate_answer(chat_id, chat_question)
 
@@ -259,18 +264,26 @@ async def create_stream_question_handler(
     try:
         logger.info(f"Streaming request for {chat_question.model}")
         check_user_limit(current_user)
-        if not brain_id:
-            brain_id = get_default_user_brain_or_create_new(current_user).brain_id
-
-        gpt_answer_generator = OpenAIBrainPicking(
-            chat_id=str(chat_id),
-            model=(brain_details or chat_question).model if current_user.user_openai_api_key else "gpt-3.5-turbo",
-            max_tokens=(brain_details or chat_question).max_tokens if current_user.user_openai_api_key else 0,
-            temperature=(brain_details or chat_question).temperature if current_user.user_openai_api_key else 256,
-            brain_id=str(brain_id),
-            user_openai_api_key=current_user.user_openai_api_key,  # pyright: ignore reportPrivateUsage=none
-            streaming=True,
-        )
+        gpt_answer_generator: HeadlessQA | OpenAIBrainPicking
+        if brain_id:
+            gpt_answer_generator = OpenAIBrainPicking(
+                chat_id=str(chat_id),
+                model=(brain_details or chat_question).model if current_user.user_openai_api_key else "gpt-3.5-turbo",
+                max_tokens=(brain_details or chat_question).max_tokens if current_user.user_openai_api_key else 0,
+                temperature=(brain_details or chat_question).temperature if current_user.user_openai_api_key else 256,
+                brain_id=str(brain_id),
+                user_openai_api_key=current_user.user_openai_api_key,  # pyright: ignore reportPrivateUsage=none
+                streaming=True,
+            )
+        else:
+            gpt_answer_generator = HeadlessQA(
+                model=chat_question.model if current_user.user_openai_api_key else "gpt-3.5-turbo",
+                temperature=chat_question.temperature if current_user.user_openai_api_key else 256,
+                max_tokens=chat_question.max_tokens if current_user.user_openai_api_key else 0,
+                user_openai_api_key=current_user.user_openai_api_key,  # pyright: ignore reportPrivateUsage=none
+                chat_id=str(chat_id),
+                streaming=True,
+            )
 
         print("streaming")
         return StreamingResponse(
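With this dispatch in place, a question sent without a brain_id now goes through HeadlessQA instead of a default brain. A client-side sketch, assuming the router exposes the usual /chat/{chat_id}/question path with brain_id as an optional query parameter and bearer-token auth (URL, token, and payload shape here are placeholders and assumptions, not confirmed by this diff):

import requests

API = "http://localhost:5050"  # placeholder backend URL
chat_id = "00000000-0000-0000-0000-000000000000"  # placeholder chat id

# Omitting ?brain_id=... is what routes the question to HeadlessQA.
resp = requests.post(
    f"{API}/chat/{chat_id}/question",
    headers={"Authorization": "Bearer <token>"},  # placeholder token
    json={"question": "Hello", "model": "gpt-3.5-turbo", "temperature": 0.0, "max_tokens": 256},
)
print(resp.json()["assistant"])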
