Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change MultiQuery Prompt, Add Hybrid Search (BM25 + Embedding), Cohere Reranker & LLM Chain Filter #247

Merged
merged 9 commits into from
Jan 10, 2024
57 changes: 51 additions & 6 deletions api/ask_astro/chains/answer_question.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,14 @@
MessagesPlaceholder,
SystemMessagePromptTemplate,
)
from langchain.retrievers import MultiQueryRetriever
from langchain.prompts.prompt import PromptTemplate
from langchain.retrievers import ContextualCompressionRetriever, MultiQueryRetriever
from langchain.retrievers.document_compressors import CohereRerank, LLMChainFilter
from langchain.retrievers.weaviate_hybrid_search import WeaviateHybridSearchRetriever

from ask_astro.clients.weaviate_ import docsearch
from ask_astro.config import AzureOpenAIParams
from ask_astro.chains.custom_llm_filter_prompt import custom_llm_chain_filter_prompt_template
from ask_astro.clients.weaviate_ import client
from ask_astro.config import AzureOpenAIParams, CohereConfig, WeaviateConfig
from ask_astro.settings import (
CONVERSATIONAL_RETRIEVAL_LLM_CHAIN_DEPLOYMENT_NAME,
CONVERSATIONAL_RETRIEVAL_LLM_CHAIN_TEMPERATURE,
Expand All @@ -32,19 +36,60 @@
HumanMessagePromptTemplate.from_template("{question}"),
]

# Weaviate hybrid search retriever: combines BM25 keyword scoring with vector
# similarity. `alpha` weights keyword vs. embedding scores and `k` caps the
# number of hits returned (both configured via WeaviateConfig env vars).
hybrid_retriever = WeaviateHybridSearchRetriever(
    client=client,
    index_name=WeaviateConfig.index_name,
    text_key=WeaviateConfig.text_key,
    attributes=WeaviateConfig.attributes,
    k=WeaviateConfig.k,
    alpha=WeaviateConfig.alpha,
)

# Initialize a MultiQueryRetriever using AzureChatOpenAI and Weaviate.
retriever = MultiQueryRetriever.from_llm(
# Prompt used by the MultiQueryRetriever below: the LLM rewrites the user's
# question into 2 alternative phrasings (one per line) so retrieval is less
# sensitive to the exact wording of the original question.
user_question_rewording_prompt_template = PromptTemplate(
    input_variables=["question"],
    template="""You are an AI language model assistant. Your task is
to generate 2 different versions of the given user
question to retrieve relevant documents from a vector database.
By rewording the original question, expanding on abbreviated words if there are any,
and generating multiple perspectives on the user question,
your goal is to help the user overcome some of the limitations
of distance-based similarity search. Provide these alternative
questions separated by newlines. Original question: {question}""",
)
# MultiQueryRetriever: an Azure OpenAI model rewords the user question into
# multiple variants (see prompt above) and the hybrid retriever is queried
# with each of them; include_original keeps the user's own wording too.
# NOTE: the stale pre-merge kwarg `retriever=docsearch.as_retriever()` was
# removed — the call must pass `retriever` exactly once.
multi_query_retriever = MultiQueryRetriever.from_llm(
    llm=AzureChatOpenAI(
        **AzureOpenAIParams.us_east,
        deployment_name=MULTI_QUERY_RETRIEVER_DEPLOYMENT_NAME,
        temperature=MULTI_QUERY_RETRIEVER_TEMPERATURE,
    ),
    include_original=True,
    prompt=user_question_rewording_prompt_template,
    retriever=hybrid_retriever,
)

# Rerank: Cohere's reranker re-scores the documents returned by the
# multi-query retriever and keeps only the top `rerank_top_n` most relevant.
cohere_reranker_compressor = CohereRerank(user_agent="langchain", top_n=CohereConfig.rerank_top_n)
reranker_retriever = ContextualCompressionRetriever(
    base_compressor=cohere_reranker_compressor, base_retriever=multi_query_retriever
)

# GPT-3.5 to check over relevancy of the remaining documents: an LLMChainFilter
# asks the model a yes/no relevance question per document (using the custom
# prompt/parser) and drops the ones judged irrelevant. temperature=0.0 keeps
# the judgement deterministic.
llm_chain_filter = LLMChainFilter.from_llm(
    AzureChatOpenAI(
        **AzureOpenAIParams.us_east,
        deployment_name=CONVERSATIONAL_RETRIEVAL_LLM_CHAIN_DEPLOYMENT_NAME,
        temperature=0.0,
    ),
    custom_llm_chain_filter_prompt_template,
)
# Final retriever in the pipeline: hybrid search -> multi-query -> Cohere
# rerank -> LLM relevance filter.
llm_chain_filter_compression_retriever = ContextualCompressionRetriever(
    base_compressor=llm_chain_filter, base_retriever=reranker_retriever
)

# Set up a ConversationalRetrievalChain to generate answers using the retriever.
answer_question_chain = ConversationalRetrievalChain(
retriever=retriever,
retriever=llm_chain_filter_compression_retriever,
return_source_documents=True,
question_generator=LLMChain(
llm=AzureChatOpenAI(
Expand Down
39 changes: 39 additions & 0 deletions api/ask_astro/chains/custom_llm_filter_prompt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from langchain.retrievers.document_compressors.chain_filter_prompt import (
prompt_template,
)
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.prompts import PromptTemplate


class CustomBooleanOutputParser(BaseOutputParser[bool]):
    """Parse the output of an LLM call to a boolean. Default to True if response not formatted correctly.

    Lenient re-implementation of langchain's BooleanOutputParser
    (see langchain-ai/langchain#11408): the stock parser's strict YES/NO
    check throws unwanted errors at runtime when the model answers in a
    free-form way; this version never raises and defaults to True.
    """

    true_val: str = "YES"  # the string value that should be parsed as True
    false_val: str = "NO"  # the string value that should be parsed as False

    def parse(self, text: str) -> bool:
        """Parse the output of an LLM call to a boolean by looking for YES/NO in the output.

        Args:
            text: output of a language model.

        Returns:
            False when the negative token appears as a whole word in the
            output, True otherwise (including malformed output).
        """
        # Match whole words rather than substrings: a plain
        # `self.false_val in text` test would misread any word that merely
        # contains "NO" (e.g. "NOT", "KNOW", "NONE") as a negative answer.
        # Punctuation is stripped so answers like "No." still match.
        words = {word.strip(".,;:!?'\"()") for word in text.strip().upper().split()}
        return self.false_val not in words

    @property
    def _type(self) -> str:
        """Snake-case string identifier for an output parser type."""
        return "custom_boolean_output_parser"


# langchain's stock chain-filter prompt text, but paired with the lenient
# CustomBooleanOutputParser defined above so unexpected model answers do not
# raise at runtime.
custom_llm_chain_filter_prompt_template = PromptTemplate(
    template=prompt_template,
    input_variables=["question", "context"],
    output_parser=CustomBooleanOutputParser(),
)
8 changes: 8 additions & 0 deletions api/ask_astro/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,11 @@ class WeaviateConfig:
index_name = os.environ.get("WEAVIATE_INDEX_NAME")
text_key = os.environ.get("WEAVIATE_TEXT_KEY")
attributes = os.environ.get("WEAVIATE_ATTRIBUTES", "").split(",")
# Cast explicitly: os.environ.get returns a *string* whenever the variable is
# set, so without the casts `k`/`alpha` would silently become str instead of
# the numeric values the hybrid retriever expects. This also matches the
# int() cast already used by CohereConfig.rerank_top_n.
k = int(os.environ.get("WEAVIATE_HYBRID_SEARCH_TOP_K", 100))
alpha = float(os.environ.get("WEAVIATE_HYBRID_SEARCH_ALPHA", 0.5))


class CohereConfig:
    """Contains the config variables for the Cohere API."""

    # Number of documents the Cohere reranker keeps (COHERE_RERANK_TOP_N,
    # default 10); cast to int because environment values are strings.
    rerank_top_n = int(os.environ.get("COHERE_RERANK_TOP_N", 10))
Loading
Loading