fix: quiv core stream duplicate and quivr-core rag tests #2852

Merged 7 commits on Jul 12, 2024
23 changes: 15 additions & 8 deletions backend/api/quivr_api/packages/quivr_core/quivr_rag.py
@@ -227,14 +227,21 @@ async def answer_astream(
self.supports_func_calling,
)

if self.supports_func_calling and len(answer_str) > 0:
diff_answer = answer_str[len(prev_answer) :]
parsed_chunk = ParsedRAGChunkResponse(
answer=diff_answer,
metadata=RAGResponseMetadata(),
)
prev_answer += diff_answer
yield parsed_chunk
if len(answer_str) > 0:
if self.supports_func_calling:
diff_answer = answer_str[len(prev_answer) :]
if len(diff_answer) > 0:
parsed_chunk = ParsedRAGChunkResponse(
answer=diff_answer,
metadata=RAGResponseMetadata(),
)
prev_answer += diff_answer
yield parsed_chunk
else:
yield ParsedRAGChunkResponse(
answer=answer_str,
metadata=RAGResponseMetadata(),
)

# Last chunk provides metadata
yield ParsedRAGChunkResponse(
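For function-calling models, the hunk above tracks `prev_answer` and yields only the unseen suffix of the cumulative `answer_str`, skipping empty deltas; non-function-calling chunks are forwarded as-is. A minimal sketch of that slicing logic in isolation (plain Python, independent of the surrounding class):

```python
def emit_deltas(cumulative_answers):
    """Yield only the new suffix of each cumulative answer string."""
    prev_answer = ""
    for answer_str in cumulative_answers:
        diff_answer = answer_str[len(prev_answer):]
        if diff_answer:
            prev_answer += diff_answer
            yield diff_answer


# Cumulative answers as a function-calling model streams them:
assert list(emit_deltas(["Hel", "Hello", "Hello wor", "Hello world"])) == ["Hel", "lo", " wor", "ld"]
```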
2 changes: 1 addition & 1 deletion backend/core/quivr_core/brain.py
@@ -245,7 +245,7 @@ async def asearch(
query, k=n_results, filter=filter, fetch_k=fetch_n_neighbors
)

return [SearchResult(chunk=d, score=s) for d, s in result]
return [SearchResult(chunk=d, distance=s) for d, s in result]

def get_chat_history(self, chat_id: UUID):
return self._chats[chat_id]
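With `asearch` now populating `distance`, callers read the similarity value from that field instead of `score`. A rough usage sketch; the `brain` instance, the query text, and the `n_results` keyword are placeholders rather than values taken from this diff:

```python
async def show_results(brain):
    # `brain` is assumed to be an existing quivr_core Brain instance.
    results = await brain.asearch("What is Quivr?", n_results=5)
    for r in results:
        print(r.chunk.page_content, r.distance)  # formerly r.score
```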
38 changes: 2 additions & 36 deletions backend/core/quivr_core/llm/__init__.py
@@ -1,37 +1,3 @@
from langchain_core.language_models.chat_models import BaseChatModel
from pydantic.v1 import SecretStr
from .llm_endpoint import LLMEndpoint

from quivr_core.config import LLMEndpointConfig
from quivr_core.utils import model_supports_function_calling


class LLMEndpoint:
def __init__(self, llm_config: LLMEndpointConfig, llm: BaseChatModel):
self._config = llm_config
self._llm = llm
self._supports_func_calling = model_supports_function_calling(
self._config.model
)

def get_config(self):
return self._config

@classmethod
def from_config(cls, config: LLMEndpointConfig = LLMEndpointConfig()):
try:
from langchain_openai import ChatOpenAI

_llm = ChatOpenAI(
model=config.model,
api_key=SecretStr(config.llm_api_key) if config.llm_api_key else None,
base_url=config.llm_base_url,
)
return cls(llm=_llm, llm_config=config)

except ImportError as e:
raise ImportError(
"Please provide a valid BaseLLM or install quivr-core['base'] package"
) from e

def supports_func_calling(self) -> bool:
return self._supports_func_calling
__all__ = ["LLMEndpoint"]
37 changes: 37 additions & 0 deletions backend/core/quivr_core/llm/llm_endpoint.py
@@ -0,0 +1,37 @@
from langchain_core.language_models.chat_models import BaseChatModel
from pydantic.v1 import SecretStr

from quivr_core.config import LLMEndpointConfig
from quivr_core.utils import model_supports_function_calling


class LLMEndpoint:
def __init__(self, llm_config: LLMEndpointConfig, llm: BaseChatModel):
self._config = llm_config
self._llm = llm
self._supports_func_calling = model_supports_function_calling(
self._config.model
)

def get_config(self):
return self._config

@classmethod
def from_config(cls, config: LLMEndpointConfig = LLMEndpointConfig()):
try:
from langchain_openai import ChatOpenAI

_llm = ChatOpenAI(
model=config.model,
api_key=SecretStr(config.llm_api_key) if config.llm_api_key else None,
base_url=config.llm_base_url,
)
return cls(llm=_llm, llm_config=config)

except ImportError as e:
raise ImportError(
"Please provide a valid BaseLLM or install quivr-core['base'] package"
) from e

def supports_func_calling(self) -> bool:
return self._supports_func_calling
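The class itself is unchanged; it now lives in `llm_endpoint.py` and is re-exported from `quivr_core.llm`. A minimal construction sketch, assuming `LLMEndpointConfig` exposes the `model` and `llm_api_key` fields used by `from_config` above (the values below are placeholders):

```python
from quivr_core.config import LLMEndpointConfig
from quivr_core.llm import LLMEndpoint  # still importable from the package root

endpoint = LLMEndpoint.from_config(
    LLMEndpointConfig(model="gpt-4o", llm_api_key="sk-...")  # placeholder values
)
print(endpoint.supports_func_calling())
```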
2 changes: 1 addition & 1 deletion backend/core/quivr_core/models.py
@@ -90,4 +90,4 @@ class QuivrKnowledge(BaseModel):
# NOTE: for compatibility issues with langchain <-> PydanticV1
class SearchResult(BaseModelV1):
chunk: Document
score: float
distance: float
8 changes: 6 additions & 2 deletions backend/core/quivr_core/prompts.py
@@ -1,7 +1,11 @@
import datetime

from langchain.prompts import HumanMessagePromptTemplate, SystemMessagePromptTemplate
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
from langchain_core.prompts import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
PromptTemplate,
SystemMessagePromptTemplate,
)

# First step is to create the Rephrasing Prompt
_template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language. Keep as much details as possible from previous messages. Keep entity names and all.
30 changes: 20 additions & 10 deletions backend/core/quivr_core/quivr_rag.py
@@ -18,6 +18,7 @@
ParsedRAGChunkResponse,
ParsedRAGResponse,
QuivrKnowledge,
RAGResponseMetadata,
cited_answer,
)
from quivr_core.prompts import ANSWER_PROMPT, CONDENSE_QUESTION_PROMPT
@@ -172,6 +173,7 @@ async def answer_astream(

rolling_message = AIMessageChunk(content="")
sources = []
prev_answer = ""

async for chunk in conversational_qa_chain.astream(
{
@@ -186,21 +188,29 @@
sources = chunk["docs"] if "docs" in chunk else []

if "answer" in chunk:
rolling_message, parsed_chunk = parse_chunk_response(
rolling_message, answer_str = parse_chunk_response(
rolling_message,
chunk,
self.llm_endpoint.supports_func_calling(),
)

if (
self.llm_endpoint.supports_func_calling()
and len(parsed_chunk.answer) > 0
):
yield parsed_chunk
else:
yield parsed_chunk

# Last chunk provies
if len(answer_str) > 0:
if self.llm_endpoint.supports_func_calling():
diff_answer = answer_str[len(prev_answer) :]
if len(diff_answer) > 0:
parsed_chunk = ParsedRAGChunkResponse(
answer=diff_answer,
metadata=RAGResponseMetadata(),
)
prev_answer += diff_answer
yield parsed_chunk
else:
yield ParsedRAGChunkResponse(
answer=answer_str,
metadata=RAGResponseMetadata(),
)

# Last chunk provides metadata
yield ParsedRAGChunkResponse(
answer="",
metadata=get_chunk_metadata(rolling_message, sources),
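The same `prev_answer` slicing is mirrored here in quivr-core's `answer_astream`, so a streaming consumer can concatenate chunks as they arrive without seeing repeated text. A hedged usage sketch; `rag` stands for an instance of the RAG class defined in this module, and the argument names passed to `answer_astream` are illustrative only, not taken from this diff:

```python
async def stream_answer(rag, question, history, list_files):
    # Each chunk.answer now carries only new text, never a repeat of what
    # was already streamed, so it can be printed or appended directly.
    async for chunk in rag.answer_astream(question, history, list_files):
        print(chunk.answer, end="", flush=True)
```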
61 changes: 19 additions & 42 deletions backend/core/quivr_core/utils.py
@@ -1,18 +1,16 @@
import logging
from typing import Any, Dict, List, Tuple, no_type_check
from typing import Any, List, Tuple, no_type_check

from langchain.schema import (
from langchain_core.messages import (
AIMessage,
BaseMessage,
HumanMessage,
SystemMessage,
format_document,
)
from langchain_core.messages.ai import AIMessageChunk
from langchain_core.prompts import format_document

from quivr_core.models import (
ChatMessage,
ParsedRAGChunkResponse,
ParsedRAGResponse,
QuivrKnowledge,
RAGResponseMetadata,
@@ -43,19 +41,6 @@ def model_supports_function_calling(model_name: str):
return model_name in models_supporting_function_calls


def format_chat_history(
history: List[ChatMessage],
) -> List[Dict[str, str]]:
"""Format the chat history into a list of HumanMessage and AIMessage"""
formatted_history = []
for chat in history:
if chat.user_message:
formatted_history.append(HumanMessage(content=chat.user_message))
if chat.assistant:
formatted_history.append(AIMessage(content=chat.assistant))
return formatted_history


def format_history_to_openai_mesages(
tuple_history: List[Tuple[str, str]], system_message: str, question: str
) -> List[BaseMessage]:
@@ -73,14 +58,6 @@ def cited_answer_filter(tool):
return tool["name"] == "cited_answer"


def get_prev_message_str(msg: AIMessageChunk) -> str:
if msg.tool_calls:
cited_answer = next(x for x in msg.tool_calls if cited_answer_filter(x))
if "args" in cited_answer and "answer" in cited_answer["args"]:
return cited_answer["args"]["answer"]
return ""


def get_chunk_metadata(
msg: AIMessageChunk, sources: list[Any] = []
) -> RAGResponseMetadata:
@@ -106,39 +83,39 @@ def get_chunk_metadata(
return RAGResponseMetadata(**metadata)


def get_prev_message_str(msg: AIMessageChunk) -> str:
if msg.tool_calls:
cited_answer = next(x for x in msg.tool_calls if cited_answer_filter(x))
if "args" in cited_answer and "answer" in cited_answer["args"]:
return cited_answer["args"]["answer"]
return ""


# TODO: CONVOLUTED LOGIC !
# TODO(@aminediro): redo this
@no_type_check
def parse_chunk_response(
gathered_msg: AIMessageChunk,
rolling_msg: AIMessageChunk,
raw_chunk: dict[str, Any],
supports_func_calling: bool,
) -> Tuple[AIMessageChunk, ParsedRAGChunkResponse]:
) -> Tuple[AIMessageChunk, str]:
# Init with sources
answer_str = ""
# Get the previously parsed answer
prev_answer = get_prev_message_str(gathered_msg)

rolling_msg += raw_chunk["answer"]
if supports_func_calling:
gathered_msg += raw_chunk["answer"]
if gathered_msg.tool_calls:
if rolling_msg.tool_calls:
cited_answer = next(
x for x in gathered_msg.tool_calls if cited_answer_filter(x)
x for x in rolling_msg.tool_calls if cited_answer_filter(x)
)
if "args" in cited_answer:
gathered_args = cited_answer["args"]
if "answer" in gathered_args:
# Only send the difference between answer and response_tokens which was the previous answer
gathered_answer = gathered_args["answer"]
answer_str: str = gathered_answer[len(prev_answer) :]

return gathered_msg, ParsedRAGChunkResponse(
answer=answer_str, metadata=RAGResponseMetadata()
)
answer_str = gathered_args["answer"]
return rolling_msg, answer_str
else:
return gathered_msg, ParsedRAGChunkResponse(
answer=raw_chunk["answer"].content, metadata=RAGResponseMetadata()
)
return rolling_msg, raw_chunk["answer"].content


@no_type_check
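`parse_chunk_response` no longer builds a `ParsedRAGChunkResponse` itself; it returns the rolling `AIMessageChunk` plus a raw string, and the caller in `answer_astream` decides whether to diff it (function-calling models) or forward it as-is. A rough illustration of the new contract, assuming the function is importable from `quivr_core.utils` as shown in this diff:

```python
from langchain_core.messages.ai import AIMessageChunk
from quivr_core.utils import parse_chunk_response

rolling = AIMessageChunk(content="")
# Non-function-calling path: the returned string is just this chunk's content.
rolling, answer = parse_chunk_response(
    rolling, {"answer": AIMessageChunk(content="Hel")}, supports_func_calling=False
)
assert answer == "Hel"
# With supports_func_calling=True the cumulative tool-call answer is returned
# instead, and answer_astream slices it against prev_answer.
```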