From bc71397e9b716c50b7a84c8efcc8d5282db147fc Mon Sep 17 00:00:00 2001 From: David Xue Date: Tue, 2 Jan 2024 14:24:06 -0800 Subject: [PATCH 1/7] Add hybrid search and cohere reranker --- api/ask_astro/chains/answer_question.py | 47 ++++++++-- api/poetry.lock | 112 +++++++++++++++++++++++- api/pyproject.toml | 1 + 3 files changed, 153 insertions(+), 7 deletions(-) diff --git a/api/ask_astro/chains/answer_question.py b/api/ask_astro/chains/answer_question.py index fe025ac9..ef63003b 100644 --- a/api/ask_astro/chains/answer_question.py +++ b/api/ask_astro/chains/answer_question.py @@ -11,10 +11,13 @@ MessagesPlaceholder, SystemMessagePromptTemplate, ) -from langchain.retrievers import MultiQueryRetriever +from langchain.prompts.prompt import PromptTemplate +from langchain.retrievers import ContextualCompressionRetriever, MultiQueryRetriever +from langchain.retrievers.document_compressors import CohereRerank +from langchain.retrievers.weaviate_hybrid_search import WeaviateHybridSearchRetriever -from ask_astro.clients.weaviate_ import docsearch -from ask_astro.config import AzureOpenAIParams +from ask_astro.clients.weaviate_ import client +from ask_astro.config import AzureOpenAIParams, WeaviateConfig from ask_astro.settings import ( CONVERSATIONAL_RETRIEVAL_LLM_CHAIN_DEPLOYMENT_NAME, CONVERSATIONAL_RETRIEVAL_LLM_CHAIN_TEMPERATURE, @@ -32,19 +35,51 @@ HumanMessagePromptTemplate.from_template("{question}"), ] +hybrid_retriever = WeaviateHybridSearchRetriever( + client=client, + index_name=WeaviateConfig.index_name, + text_key=WeaviateConfig.text_key, + attributes=WeaviateConfig.attributes, + create_schema_if_missing="false", + k=100, + alpha=0.5, +) + +compressor = CohereRerank(user_agent="langchain", top_n=4) +compression_retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=hybrid_retriever) + +user_question_rewroding_prompt_template = PromptTemplate( + input_variables=["question"], + template="""You are an AI language model assistant. Your task is + to generate 2 different versions of the given user + question to retrieve relevant documents from a vector database. + By rewording the original question, expanding on abbreviated words if there are any, + and generating multiple perspectives on the user question, + your goal is to help the user overcome some of the limitations + of distance-based similarity search. Provide these alternative + questions separated by newlines. Original question: {question}""", +) + # Initialize a MultiQueryRetriever using AzureChatOpenAI and Weaviate. -retriever = MultiQueryRetriever.from_llm( +multi_query_retriever = MultiQueryRetriever.from_llm( llm=AzureChatOpenAI( **AzureOpenAIParams.us_east, deployment_name=MULTI_QUERY_RETRIEVER_DEPLOYMENT_NAME, temperature=MULTI_QUERY_RETRIEVER_TEMPERATURE, ), - retriever=docsearch.as_retriever(), + include_original=True, + prompt=user_question_rewroding_prompt_template, + retriever=compression_retriever, +) +final_compressor = CohereRerank(user_agent="langchain", top_n=8) + +final_compression_retriever = ContextualCompressionRetriever( + base_compressor=final_compressor, base_retriever=multi_query_retriever ) # Set up a ConversationalRetrievalChain to generate answers using the retriever. 
answer_question_chain = ConversationalRetrievalChain( - retriever=retriever, + retriever=final_compression_retriever, return_source_documents=True, question_generator=LLMChain( llm=AzureChatOpenAI( diff --git a/api/poetry.lock b/api/poetry.lock index 638ed76b..fa7c71c8 100644 --- a/api/poetry.lock +++ b/api/poetry.lock @@ -240,6 +240,17 @@ files = [ [package.dependencies] cryptography = ">=3.2" +[[package]] +name = "backoff" +version = "2.2.1" +description = "Function decoration for backoff and retry" +optional = false +python-versions = ">=3.7,<4.0" +files = [ + {file = "backoff-2.2.1-py3-none-any.whl", hash = "sha256:63579f9a0628e06278f7e47b7d7d5b6ce20dc65c5e96a6f3ca99a6adca0396e8"}, + {file = "backoff-2.2.1.tar.gz", hash = "sha256:03f829f5bb1923180821643f8753b0502c3b682293992485b0eef2807afa5cba"}, +] + [[package]] name = "brotli" version = "1.1.0" @@ -590,6 +601,25 @@ files = [ [package.dependencies] colorama = {version = "*", markers = "platform_system == \"Windows\""} +[[package]] +name = "cohere" +version = "4.39" +description = "Python SDK for the Cohere API" +optional = false +python-versions = ">=3.8,<4.0" +files = [ + {file = "cohere-4.39-py3-none-any.whl", hash = "sha256:7f157b7ac0a70b1dda77dc56c4fc063e8d21efcd2bb13759cd5b6839080405e7"}, + {file = "cohere-4.39.tar.gz", hash = "sha256:9e94bb1e5b2e2d464738e0ab3c99ed2879c043cccc90ecbeffd124e81867745d"}, +] + +[package.dependencies] +aiohttp = ">=3.0,<4.0" +backoff = ">=2.0,<3.0" +fastavro = ">=1.8,<2.0" +importlib_metadata = ">=6.0,<7.0" +requests = ">=2.25.0,<3.0.0" +urllib3 = ">=1.26,<3" + [[package]] name = "colorama" version = "0.4.6" @@ -706,6 +736,52 @@ files = [ {file = "faiss_cpu-1.7.4-cp39-cp39-win_amd64.whl", hash = "sha256:98459ceeeb735b9df1a5b94572106ffe0a6ce740eb7e4626715dd218657bb4dc"}, ] +[[package]] +name = "fastavro" +version = "1.9.2" +description = "Fast read/write of AVRO files" +optional = false +python-versions = ">=3.8" +files = [ + {file = "fastavro-1.9.2-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:223cecf135fd29b83ca6a30035b15b8db169aeaf8dc4f9a5d34afadc4b31638a"}, + {file = "fastavro-1.9.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e08c9be8c6f7eed2cf30f8b64d50094cba38a81b751c7db9f9c4be2656715259"}, + {file = "fastavro-1.9.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:394f06cc865c6fbae3bbca323633a28a5d914c55dc2c1cdefb75432456ef8f6f"}, + {file = "fastavro-1.9.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:7a7caadd47bdd04bda534ff70b4b98d2823800c488fd911918115aec4c4dc09b"}, + {file = "fastavro-1.9.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:68478a1b8a583d83ad6550e9dceac6cbb148a99a52c3559a0413bf4c0b9c8786"}, + {file = "fastavro-1.9.2-cp310-cp310-win_amd64.whl", hash = "sha256:b59a1123f1d534743af33fdbda80dd7b9146685bdd7931eae12bee6203065222"}, + {file = "fastavro-1.9.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:887c20dc527a549764c91f9e48ece071f2f26d217af66ebcaeb87bf29578fee5"}, + {file = "fastavro-1.9.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:46458f78b481c12db62d3d8a81bae09cb0b5b521c0d066c6856fc2746908d00d"}, + {file = "fastavro-1.9.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9f4a2a4bed0e829f79fa1e4f172d484b2179426e827bcc80c0069cc81328a5af"}, + {file = "fastavro-1.9.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:6167f9bbe1c5a28fbc2db767f97dbbb4981065e6eeafd4e613f6fe76c576ffd4"}, + {file = 
"fastavro-1.9.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d574bc385f820da0404528157238de4e5fdd775d2cb3d05b3b0f1b475d493837"}, + {file = "fastavro-1.9.2-cp311-cp311-win_amd64.whl", hash = "sha256:ec600eb15b3ec931904c5bf8da62b3b725cb0f369add83ba47d7b5e9322f92a0"}, + {file = "fastavro-1.9.2-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:c82b0761503420cd45f7f50bc31975ac1c75b5118e15434c1d724b751abcc249"}, + {file = "fastavro-1.9.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:db62d9b8c944b8d9c481e5f980d5becfd034bdd58c72e27c9333bd504b06bda0"}, + {file = "fastavro-1.9.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:65e61f040bc9494646f42a466e9cd428783b82d7161173f3296710723ba5a453"}, + {file = "fastavro-1.9.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:6278b93cdd5bef1778c0232ce1f265137f90bc6be97a5c1dd7e0d99a406c0488"}, + {file = "fastavro-1.9.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:cd003ddea5d89720194b6e57011c37221d9fc4ddc750e6f4723516eb659be686"}, + {file = "fastavro-1.9.2-cp312-cp312-win_amd64.whl", hash = "sha256:43f09d100a26e8b59f30dde664d93e423b648e008abfc43132608a18fe8ddcc2"}, + {file = "fastavro-1.9.2-cp38-cp38-macosx_11_0_x86_64.whl", hash = "sha256:3ddffeff5394f285c69f9cd481f47b6cf62379840cdbe6e0dc74683bd589b56e"}, + {file = "fastavro-1.9.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4e75a2b2ec697d2058a7d96522e921f03f174cf9049ace007c24be7ab58c5370"}, + {file = "fastavro-1.9.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5fd2e8fd0567483eb0fdada1b979ad4d493305dfdd3f351c82a87df301f0ae1f"}, + {file = "fastavro-1.9.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:c652dbe3f087c943a5b89f9a50a574e64f23790bfbec335ce2b91a2ae354a443"}, + {file = "fastavro-1.9.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:bba73e9a1822162f1b3a43de0362f29880014c5c4d49d63ad7fcce339ef73ea2"}, + {file = "fastavro-1.9.2-cp38-cp38-win_amd64.whl", hash = "sha256:beeef2964bbfd09c539424808539b956d7425afbb7055b89e2aa311374748b56"}, + {file = "fastavro-1.9.2-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:d5fa48266d75e057b27d8586b823d6d7d7c94593fd989d75033eb4c8078009fb"}, + {file = "fastavro-1.9.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b69aeb0d063f5955a0e412f9779444fc452568a49db75a90a8d372f9cb4a01c8"}, + {file = "fastavro-1.9.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ce336c59fb40fdb8751bda8cc6076cfcdf9767c3c107f6049e049166b26c61f"}, + {file = "fastavro-1.9.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:581036e18661f045415a51ad528865e1d7ba5a9690a3dede9e6ea50f94ed6c4c"}, + {file = "fastavro-1.9.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:39b6b5c3cda569c0a130fd2d08d4c53a326ede7e05174a24eda08f7698f70eda"}, + {file = "fastavro-1.9.2-cp39-cp39-win_amd64.whl", hash = "sha256:d33e40f246bf07f106f9d2da68d0234efcc62276b6e35bde00ff920ea7f871fd"}, + {file = "fastavro-1.9.2.tar.gz", hash = "sha256:5c1ffad986200496bd69b5c4748ae90b5d934d3b1456f33147bee3a0bb17f89b"}, +] + +[package.extras] +codecs = ["cramjam", "lz4", "zstandard"] +lz4 = ["lz4"] +snappy = ["cramjam"] +zstandard = ["zstandard"] + [[package]] name = "firebase-admin" version = "6.2.0" @@ -1359,6 +1435,25 @@ files = [ {file = "idna-3.4.tar.gz", hash = "sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4"}, ] +[[package]] +name = "importlib-metadata" +version = "6.11.0" +description = "Read 
metadata from Python packages" +optional = false +python-versions = ">=3.8" +files = [ + {file = "importlib_metadata-6.11.0-py3-none-any.whl", hash = "sha256:f0afba6205ad8f8947c7d338b5342d5db2afbfd82f9cbef7879a9539cc12eb9b"}, + {file = "importlib_metadata-6.11.0.tar.gz", hash = "sha256:1231cf92d825c9e03cfc4da076a16de6422c863558229ea0b22b675657463443"}, +] + +[package.dependencies] +zipp = ">=0.5" + +[package.extras] +docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"] +perf = ["ipython"] +testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)", "pytest-ruff"] + [[package]] name = "iniconfig" version = "2.0.0" @@ -3141,7 +3236,22 @@ pytz = ">=2018.9" requests = ">=2.14.2" six = ">=1.14.0" +[[package]] +name = "zipp" +version = "3.17.0" +description = "Backport of pathlib-compatible object wrapper for zip files" +optional = false +python-versions = ">=3.8" +files = [ + {file = "zipp-3.17.0-py3-none-any.whl", hash = "sha256:0e923e726174922dce09c53c59ad483ff7bbb8e572e00c7f7c46b88556409f31"}, + {file = "zipp-3.17.0.tar.gz", hash = "sha256:84e64a1c28cf7e91ed2078bb8cc8c259cb19b76942096c8d7b84947690cabaf0"}, +] + +[package.extras] +docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"] +testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy (>=0.9.1)", "pytest-ruff"] + [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "5f62f342e51b651110a2dde264cca468c11e067896cf9a2cf2eb8eb384557525" +content-hash = "e09032d19cea48bd6975884764686247c70a2f2dd35de72bda834a126dc8610c" diff --git a/api/pyproject.toml b/api/pyproject.toml index 3e46b8af..df54a61f 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -27,6 +27,7 @@ pydantic = "^2.3.0" gunicorn = "^21.2.0" uvicorn = "^0.23.2" tenacity = "^8.2.3" +cohere = "==4.39" [tool.poetry.group.dev.dependencies] pytest = "^7.4.2" From 3e757f1dbed454f087042aa12b07c44d6a16e1aa Mon Sep 17 00:00:00 2001 From: David Xue Date: Wed, 3 Jan 2024 16:33:16 -0800 Subject: [PATCH 2/7] Add LLM Chain Filter + Shift Where Rerank Takes Place --- api/ask_astro/chains/answer_question.py | 42 +++++++++++++++---------- api/ask_astro/config.py | 8 +++++ 2 files changed, 33 insertions(+), 17 deletions(-) diff --git a/api/ask_astro/chains/answer_question.py b/api/ask_astro/chains/answer_question.py index ef63003b..7dfc4111 100644 --- a/api/ask_astro/chains/answer_question.py +++ b/api/ask_astro/chains/answer_question.py @@ -13,11 +13,11 @@ ) from langchain.prompts.prompt import PromptTemplate from langchain.retrievers import ContextualCompressionRetriever, MultiQueryRetriever -from langchain.retrievers.document_compressors import CohereRerank +from langchain.retrievers.document_compressors import CohereRerank, LLMChainFilter from langchain.retrievers.weaviate_hybrid_search import WeaviateHybridSearchRetriever from ask_astro.clients.weaviate_ import client -from ask_astro.config import AzureOpenAIParams, WeaviateConfig +from ask_astro.config import AzureOpenAIParams, CohereConfig, WeaviateConfig from ask_astro.settings import ( 
CONVERSATIONAL_RETRIEVAL_LLM_CHAIN_DEPLOYMENT_NAME, CONVERSATIONAL_RETRIEVAL_LLM_CHAIN_TEMPERATURE, @@ -40,15 +40,12 @@ index_name=WeaviateConfig.index_name, text_key=WeaviateConfig.text_key, attributes=WeaviateConfig.attributes, - create_schema_if_missing="false", - k=100, - alpha=0.5, + k=WeaviateConfig.k, + alpha=WeaviateConfig.alpha, ) -compressor = CohereRerank(user_agent="langchain", top_n=4) -compression_retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=hybrid_retriever) - -user_question_rewroding_prompt_template = PromptTemplate( +# Initialize a MultiQueryRetriever using AzureChatOpenAI and Weaviate. +user_question_rewording_prompt_template = PromptTemplate( input_variables=["question"], template="""You are an AI language model assistant. Your task is to generate 2 different versions of the given user @@ -59,8 +56,6 @@ of distance-based similarity search. Provide these alternative questions separated by newlines. Original question: {question}""", ) - -# Initialize a MultiQueryRetriever using AzureChatOpenAI and Weaviate. multi_query_retriever = MultiQueryRetriever.from_llm( llm=AzureChatOpenAI( **AzureOpenAIParams.us_east, @@ -68,18 +63,31 @@ temperature=MULTI_QUERY_RETRIEVER_TEMPERATURE, ), include_original=True, - prompt=user_question_rewroding_prompt_template, - retriever=compression_retriever, + prompt=user_question_rewording_prompt_template, + retriever=hybrid_retriever, ) -final_compressor = CohereRerank(user_agent="langchain", top_n=8) -final_compression_retriever = ContextualCompressionRetriever( - base_compressor=final_compressor, base_retriever=multi_query_retriever +# Rerank +cohere_reranker_compressor = CohereRerank(user_agent="langchain", top_n=CohereConfig.rerank_top_n) +reranker_retriever = ContextualCompressionRetriever( + base_compressor=cohere_reranker_compressor, base_retriever=multi_query_retriever +) + +# GPT-3.5 to check over relevancy of the remaining documents +llm_chain_filter = LLMChainFilter.from_llm( + AzureChatOpenAI( + **AzureOpenAIParams.us_east, + deployment_name=CONVERSATIONAL_RETRIEVAL_LLM_CHAIN_DEPLOYMENT_NAME, + temperature=0.0, + ) +) +llm_chain_filter_compression_retriever = ContextualCompressionRetriever( + base_compressor=llm_chain_filter, base_retriever=reranker_retriever ) # Set up a ConversationalRetrievalChain to generate answers using the retriever. 
answer_question_chain = ConversationalRetrievalChain( - retriever=final_compression_retriever, + retriever=llm_chain_filter_compression_retriever, return_source_documents=True, question_generator=LLMChain( llm=AzureChatOpenAI( diff --git a/api/ask_astro/config.py b/api/ask_astro/config.py index 9e110607..e118d43f 100644 --- a/api/ask_astro/config.py +++ b/api/ask_astro/config.py @@ -61,3 +61,11 @@ class WeaviateConfig: index_name = os.environ.get("WEAVIATE_INDEX_NAME") text_key = os.environ.get("WEAVIATE_TEXT_KEY") attributes = os.environ.get("WEAVIATE_ATTRIBUTES", "").split(",") + k = os.environ.get("WEAVIATE_HYBRID_SEARCH_TOP_K", 100) + alpha = os.environ.get("WEAVIATE_HYBRID_SEARCH_ALPHA", 0.5) + + +class CohereConfig: + """Contains the config variables for the Cohere API.""" + + rerank_top_n = int(os.environ.get("COHERE_RERANK_TOP_N", 10)) From 00306ca3641048c9bb98923399d1221c85138b11 Mon Sep 17 00:00:00 2001 From: David Xue Date: Thu, 4 Jan 2024 15:59:13 -0800 Subject: [PATCH 3/7] Add Custom Boolean LLM Parser to Bypass Error with LLM Chain Filter --- api/ask_astro/chains/answer_question.py | 4 +- .../chains/custom_llm_filter_prompt.py | 39 +++++++++++++++++++ 2 files changed, 42 insertions(+), 1 deletion(-) create mode 100644 api/ask_astro/chains/custom_llm_filter_prompt.py diff --git a/api/ask_astro/chains/answer_question.py b/api/ask_astro/chains/answer_question.py index 7dfc4111..0104ef7a 100644 --- a/api/ask_astro/chains/answer_question.py +++ b/api/ask_astro/chains/answer_question.py @@ -16,6 +16,7 @@ from langchain.retrievers.document_compressors import CohereRerank, LLMChainFilter from langchain.retrievers.weaviate_hybrid_search import WeaviateHybridSearchRetriever +from ask_astro.chains.custom_llm_filter_prompt import custom_llm_chain_filter_prompt_template from ask_astro.clients.weaviate_ import client from ask_astro.config import AzureOpenAIParams, CohereConfig, WeaviateConfig from ask_astro.settings import ( @@ -79,7 +80,8 @@ **AzureOpenAIParams.us_east, deployment_name=CONVERSATIONAL_RETRIEVAL_LLM_CHAIN_DEPLOYMENT_NAME, temperature=0.0, - ) + ), + custom_llm_chain_filter_prompt_template, ) llm_chain_filter_compression_retriever = ContextualCompressionRetriever( base_compressor=llm_chain_filter, base_retriever=reranker_retriever diff --git a/api/ask_astro/chains/custom_llm_filter_prompt.py b/api/ask_astro/chains/custom_llm_filter_prompt.py new file mode 100644 index 00000000..039d6467 --- /dev/null +++ b/api/ask_astro/chains/custom_llm_filter_prompt.py @@ -0,0 +1,39 @@ +from langchain.retrievers.document_compressors.chain_filter_prompt import ( + prompt_template, +) +from langchain_core.output_parsers import BaseOutputParser +from langchain_core.prompts import PromptTemplate + + +class CustomBooleanOutputParser(BaseOutputParser[bool]): + """Parse the output of an LLM call to a boolean. Defaults to True if the response is not formatted correctly.""" + + true_val: str = "YES" + """The string value that should be parsed as True.""" + false_val: str = "NO" + """The string value that should be parsed as False.""" + + def parse(self, text: str) -> bool: + """Parse the output of an LLM call to a boolean by checking if YES/NO is contained in the output. + + Args: + text: output of a language model.
+ + Returns: + boolean + + """ + cleaned_text = text.strip().upper() + return self.false_val not in cleaned_text + + @property + def _type(self) -> str: + """Snake-case string identifier for an output parser type.""" + return "custom_boolean_output_parser" + + +custom_llm_chain_filter_prompt_template = PromptTemplate( + template=prompt_template, + input_variables=["question", "context"], + output_parser=CustomBooleanOutputParser(), +) From 1c6764ee86674e03dd88e6a5a75d72013c85a69a Mon Sep 17 00:00:00 2001 From: David Xue Date: Thu, 4 Jan 2024 22:12:18 -0800 Subject: [PATCH 4/7] Add comment --- api/ask_astro/chains/custom_llm_filter_prompt.py | 1 + 1 file changed, 1 insertion(+) diff --git a/api/ask_astro/chains/custom_llm_filter_prompt.py b/api/ask_astro/chains/custom_llm_filter_prompt.py index 039d6467..a14eab42 100644 --- a/api/ask_astro/chains/custom_llm_filter_prompt.py +++ b/api/ask_astro/chains/custom_llm_filter_prompt.py @@ -32,6 +32,7 @@ def _type(self) -> str: return "custom_boolean_output_parser" +# custom_llm_chain_filter_prompt_template = PromptTemplate( template=prompt_template, input_variables=["question", "context"], From a5eb57dddee7d1fdcbe2d8f8a044d1d8ce41fcd5 Mon Sep 17 00:00:00 2001 From: David Xue Date: Thu, 4 Jan 2024 22:20:10 -0800 Subject: [PATCH 5/7] Remove comment --- api/ask_astro/chains/custom_llm_filter_prompt.py | 1 - 1 file changed, 1 deletion(-) diff --git a/api/ask_astro/chains/custom_llm_filter_prompt.py b/api/ask_astro/chains/custom_llm_filter_prompt.py index a14eab42..039d6467 100644 --- a/api/ask_astro/chains/custom_llm_filter_prompt.py +++ b/api/ask_astro/chains/custom_llm_filter_prompt.py @@ -32,7 +32,6 @@ def _type(self) -> str: return "custom_boolean_output_parser" -# custom_llm_chain_filter_prompt_template = PromptTemplate( template=prompt_template, input_variables=["question", "context"], From ce1f61a366bb9fca83c88350230151cc8cb6e753 Mon Sep 17 00:00:00 2001 From: David Xue Date: Mon, 8 Jan 2024 17:06:12 -0800 Subject: [PATCH 6/7] Add env var create schema if missing for weaviate --- api/ask_astro/chains/answer_question.py | 1 + 1 file changed, 1 insertion(+) diff --git a/api/ask_astro/chains/answer_question.py b/api/ask_astro/chains/answer_question.py index 0104ef7a..5ab5bc06 100644 --- a/api/ask_astro/chains/answer_question.py +++ b/api/ask_astro/chains/answer_question.py @@ -43,6 +43,7 @@ attributes=WeaviateConfig.attributes, k=WeaviateConfig.k, alpha=WeaviateConfig.alpha, + create_schema_if_missing=WeaviateConfig.create_schema_if_missing, ) # Initialize a MultiQueryRetriever using AzureChatOpenAI and Weaviate. From ca74cdd0d919cda8cc72efe67800ad207b472892 Mon Sep 17 00:00:00 2001 From: David Xue Date: Mon, 8 Jan 2024 17:13:14 -0800 Subject: [PATCH 7/7] Add Weaviate config attr for create schema if missing --- api/ask_astro/config.py | 1 + 1 file changed, 1 insertion(+) diff --git a/api/ask_astro/config.py b/api/ask_astro/config.py index e118d43f..976306b6 100644 --- a/api/ask_astro/config.py +++ b/api/ask_astro/config.py @@ -63,6 +63,7 @@ class WeaviateConfig: attributes = os.environ.get("WEAVIATE_ATTRIBUTES", "").split(",") k = os.environ.get("WEAVIATE_HYBRID_SEARCH_TOP_K", 100) alpha = os.environ.get("WEAVIATE_HYBRID_SEARCH_ALPHA", 0.5) + create_schema_if_missing = bool(os.environ.get("WEAVIATE_CREATE_SCHEMA_IF_MISSING", "").lower() == "true") class CohereConfig:
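For reference, the net effect of this series on `answer_question.py` is a four-stage retrieval pipeline: Weaviate hybrid search (BM25 and vector scores blended by `alpha`) runs once per LLM-generated rewording of the user's question, Cohere's reranker compresses the pooled candidates to the best `COHERE_RERANK_TOP_N`, and a temperature-0 GPT-3.5 yes/no filter prunes what remains before the `ConversationalRetrievalChain` answers. The sketch below condenses that end state into one place. It is illustrative only: the deployment name and sample question are placeholders, and it assumes the same environment-driven Weaviate, Azure OpenAI, and `COHERE_API_KEY` configuration the diffs read.

```python
# Minimal sketch of the retrieval pipeline as it stands after patch 7.
# Names mirror answer_question.py; deployment name and query are placeholders.
from langchain.chat_models import AzureChatOpenAI
from langchain.retrievers import ContextualCompressionRetriever, MultiQueryRetriever
from langchain.retrievers.document_compressors import CohereRerank, LLMChainFilter
from langchain.retrievers.weaviate_hybrid_search import WeaviateHybridSearchRetriever

from ask_astro.chains.custom_llm_filter_prompt import custom_llm_chain_filter_prompt_template
from ask_astro.clients.weaviate_ import client
from ask_astro.config import AzureOpenAIParams, CohereConfig, WeaviateConfig

llm = AzureChatOpenAI(
    **AzureOpenAIParams.us_east,
    deployment_name="gpt-35-turbo",  # placeholder; the diffs use settings constants
    temperature=0.0,
)

# Stage 1: hybrid search blends keyword (BM25) and vector scores; alpha=0.5
# weights the two equally and k=100 keeps a wide candidate pool per query.
hybrid_retriever = WeaviateHybridSearchRetriever(
    client=client,
    index_name=WeaviateConfig.index_name,
    text_key=WeaviateConfig.text_key,
    attributes=WeaviateConfig.attributes,
    k=WeaviateConfig.k,
    alpha=WeaviateConfig.alpha,
    create_schema_if_missing=WeaviateConfig.create_schema_if_missing,
)

# Stage 2: the LLM rewords the question to widen recall (the diffs pass a
# custom prompt asking for two rewordings; the default prompt is used here
# for brevity), and each variant plus the original hits the hybrid retriever.
multi_query_retriever = MultiQueryRetriever.from_llm(
    llm=llm, retriever=hybrid_retriever, include_original=True
)

# Stage 3: Cohere reranks the pooled candidates down to the top_n best matches.
reranker_retriever = ContextualCompressionRetriever(
    base_compressor=CohereRerank(user_agent="langchain", top_n=CohereConfig.rerank_top_n),
    base_retriever=multi_query_retriever,
)

# Stage 4: a temperature-0 yes/no relevance filter, wired to the lenient
# boolean parser from patch 3, drops documents that scored well in the rerank
# but do not actually answer the question.
final_retriever = ContextualCompressionRetriever(
    base_compressor=LLMChainFilter.from_llm(llm, custom_llm_chain_filter_prompt_template),
    base_retriever=reranker_retriever,
)

docs = final_retriever.get_relevant_documents("How do I backfill an Airflow DAG?")
```

Note the ordering that patch 2 settles on: reranking after the multi-query fan-out lets Cohere score the union of all result sets in a single pass, and the LLM filter runs last so it only spends tokens on the handful of documents that survive the rerank.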
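On the custom parser itself: LangChain's stock `BooleanOutputParser` raises a `ValueError` when the model replies with anything other than the bare YES/NO tokens, which appears to be the error patch 3 works around; `CustomBooleanOutputParser` instead keeps a document unless the reply contains `NO`. A quick illustration of that behavior (the reply strings are made up; the import path is the module the patch creates):

```python
from ask_astro.chains.custom_llm_filter_prompt import CustomBooleanOutputParser

parser = CustomBooleanOutputParser()

assert parser.parse("YES") is True
assert parser.parse("NO") is False
# Any reply without the literal substring "NO" defaults to True (keep the
# document), so chatty or truncated model output no longer raises.
assert parser.parse("The document is relevant to the question.") is True
# Caveat of the substring check: replies containing words like "cannot",
# "note", or "unknown" also parse as False and filter the document out.
assert parser.parse("I cannot tell") is False
```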