
# <a id='toc1_'></a>[Query Augmentation](#toc0_)

A good way to improve a RAG system's performance is to go beyond retrieval and generation based on the raw user query. There are many ways to rework the query so that it leads to better answers. In this notebook, we look at the broad spectrum of query augmentation and study techniques that either enhance the query's meaning or optimise the query itself for better performance.

**Table of contents**<a id='toc0_'></a>    
- [Setup](#toc2_)    
- [Enhanced Query - Conversation](#toc3_)    
- [Optimized Query](#toc4_)    
  - [Simple reformulation](#toc4_1_)    
  - [Query Decomposition](#toc4_2_)    
  - [HyDE rewriting](#toc4_3_)    
  - [Step Back prompting](#toc4_4_)    
- [Combination](#toc5_)    

<!-- vscode-jupyter-toc-config
	numbering=false
	anchor=true
	flat=false
	minLevel=1
	maxLevel=6
	/vscode-jupyter-toc-config -->
<!-- THIS CELL WILL BE REPLACED ON TOC UPDATE. DO NOT WRITE YOUR TEXT IN THIS CELL -->

# <a id='toc2_'></a>[Setup](#toc0_)

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
from pathlib import Path

from dotenv import load_dotenv

os.chdir(Path.cwd().joinpath(".."))
print(Path.cwd())
load_dotenv(override=True)

In [None]:
from langchain.chains.hyde.base import HypotheticalDocumentEmbedder
from langchain.memory import ConversationBufferMemory
from langchain.output_parsers.boolean import BooleanOutputParser
from langchain.schema import SystemMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnableLambda, RunnablePassthrough

from lib.models import embeddings, llm
from lib.prompts import (
    QA_DECOMPOSITION_PROMPT,
    QA_HISTORY_REWRITING_AGENT_PROMPT,
    QA_HISTORY_ROUTING_AGENT_PROMPT,
    QA_REFORMULATION_PROMPT,
    QA_STEP_BACK_PROMPT,
)
from lib.utils import (
    build_vector_store,
    drop_document_duplicates,
    load_documents,
    load_vector_store,
    split_documents_basic,
)

We now build the vector store and the retriever that the rest of the notebook relies on.

In [None]:
BASE_CHUNK_SIZE = 512

# build vector_store
base_documents = split_documents_basic(load_documents("data/3_docs"), BASE_CHUNK_SIZE, include_linear_index=True)

build_vector_store(
    base_documents,
    embeddings,
    collection_name="3_docs",
    distance_function="cosine",
    erase_existing=False,
)

# Load Vector store / retriever
chroma_vector_store = load_vector_store(embeddings, "3_docs")
chroma_vector_store_retriever = chroma_vector_store.as_retriever()

# <a id='toc3_'></a>[Enhanced Query - Conversation](#toc0_)

Before we delve into techniques that try to optimize the existing query for better retrieval, we will first look at a very common challenge in RAG applications: how to incorporate conversational capabilities into a chatbot? 

Here we look at the case where the user query refers to previous messages in the conversation. Two steps make sure that the agent can answer such a question well:

* Identify whether the user query is related to a previous message or not. This overlaps a lot with routing, as this classification is usually performed at the same time as the typical routing classification. For the sake of demonstration, we will do it separately. The best way to perform such a classification is to use an LLM as a judge.
* If the question is identified as conversation-related, another LLM is fed the entire conversation history and tasked with rewriting the query as a standalone question, which can then be used to perform retrieval.
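
The two steps above can be sketched in plain Python. This is a minimal illustration and not the notebook's implementation: `needs_history` and `rewrite_with_history` are hypothetical stand-ins for the two LLM calls, and the pronoun heuristic is an assumption made only so that the example runs without a model.

```python
# Hypothetical stand-ins for the two LLM calls described above.
PRONOUNS = {"it", "its", "they", "their", "them", "this", "that"}


def needs_history(query: str) -> bool:
    """Toy router: assume a query depends on history if it contains a pronoun."""
    words = {word.strip("?.,!").lower() for word in query.split()}
    return bool(words & PRONOUNS)


def rewrite_with_history(query: str, history: list[tuple[str, str]]) -> str:
    """Toy rewriter: attach the previous user message as explicit context."""
    previous_user_message, _ = history[-1]
    return f"{query.strip()} (context: {previous_user_message})"


def augment(query: str, history: list[tuple[str, str]]) -> str:
    """Rewrite the query only when the router flags it as history-dependent."""
    if history and needs_history(query):
        return rewrite_with_history(query, history)
    return query


history = [("What's the capital of France?", "The capital of France is Paris.")]
print(augment("What is 3 times 2?", history))
print(augment("What is its population ?", history))
```

In a real chain both helpers would be LLM calls; only the control flow (classify first, rewrite only when needed) carries over.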

We use the memory object to load chats into memory and to easily return the conversation history. We will instantiate our memory with a simple example.

In [None]:
memory = ConversationBufferMemory(return_messages=True, input_key="input", output_key="output")

memory.save_context(
    {"input": "What's the capital of France?"},
    {"output": "The capital of France is Paris."},
)

memory.load_memory_variables({})

We first define the two LLM prompts: a routing prompt, which receives the conversation history as messages, and a history rewriting prompt.

In [None]:
routing_prompt = ChatPromptTemplate.from_messages(
    [
        SystemMessage(content=QA_HISTORY_ROUTING_AGENT_PROMPT),
        MessagesPlaceholder(variable_name="history"),
        ("human", "{query}"),
    ]
)

history_rewriting_prompt = ChatPromptTemplate.from_template(QA_HISTORY_REWRITING_AGENT_PROMPT)

We now build the query rewriting chain as its own chain

In [None]:
history_rewriting_chain = history_rewriting_prompt | llm | StrOutputParser()

And we build the overall augmentation chain which uses it

In [None]:
history_augmentation_chain = (
    {
        "history": lambda _: memory.load_memory_variables({})["history"],
        "query": RunnablePassthrough(),
    }
    | RunnablePassthrough.assign(context_query=routing_prompt | llm | BooleanOutputParser())
    | RunnableLambda(lambda x: (history_rewriting_chain.invoke(x) if x["context_query"] else x["query"]))
)

We now try with two different examples

In [None]:
query = "What is 3 times 2?"
# The chain expects the raw query string; RunnablePassthrough forwards it under the "query" key
history_augmentation_chain.invoke(query)

Here we see that the original query has not been altered

In [None]:
query = "What is its population ?"
history_augmentation_chain.invoke(query)

But in this case, the query has been rewritten to include conversation context

# <a id='toc4_'></a>[Optimized Query](#toc0_)

We will now look at techniques that do not rely on adding more context to the query, but instead make sure the query itself is as optimized as possible to get the best results. Many of these techniques use LLMs as agents to change the query.

## <a id='toc4_1_'></a>[Simple reformulation](#toc0_)

The simplest of these techniques is a plain reformulation. The idea is to make the query more interrogative and precise, as the user input can sometimes be pretty vague. One can also list common abbreviations in the prompt so that they are spelled out explicitly in the query used for the retrieval step.
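
One possible way to integrate abbreviations is to render a small glossary into the prompt text. The sketch below is an assumption made for illustration: `ABBREVIATIONS`, `build_reformulation_prompt`, the expansions and the prompt wording are all hypothetical, and the content of the actual `QA_REFORMULATION_PROMPT` is not shown in this notebook.

```python
# Hypothetical glossary; the expansions are assumptions made for this example.
ABBREVIATIONS = {"FS": "financial services", "DS": "data science"}


def build_reformulation_prompt(query: str, abbreviations: dict[str, str]) -> str:
    """Build a reformulation prompt that lists the abbreviations to expand."""
    glossary = "\n".join(
        f"- {short}: {expanded}" for short, expanded in sorted(abbreviations.items())
    )
    return (
        "Rewrite the user input as a precise, self-contained question.\n"
        "Expand the following abbreviations when they appear:\n"
        f"{glossary}\n\n"
        f"User input: {query}"
    )


prompt = build_reformulation_prompt("data FS applications", ABBREVIATIONS)
print(prompt)
```

Keeping the glossary in a dictionary makes it easy to maintain separately from the prompt wording.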

We define a simple rewriting prompt and chain, with a vague example

In [None]:
simple_reformulation_chain = ChatPromptTemplate.from_template(QA_REFORMULATION_PROMPT) | llm | StrOutputParser()

simple_reformulation_chain.invoke("data FS applications")

We see that the reformulated query is much more digestible than the original user input.

## <a id='toc4_2_'></a>[Query Decomposition](#toc0_)

In some cases, the query may not be easy to answer as a whole, as it can be made up of multiple semi-distinct elements. In this case, a good method to use is query decomposition. The idea is very simple: ask an LLM agent to break the query down into sub-queries, use each of them separately for information retrieval, and then combine the results into a single context. We will showcase a simple implementation here, as LangChain does not have an integrated tool for this.
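
The decompose, retrieve and combine flow can be sketched without any model. Everything below is a toy under stated assumptions: `decompose` splits on the word "and" in place of an LLM, and `retrieve` ranks documents by shared words in place of a vector store.

```python
def decompose(query: str) -> list[str]:
    """Toy stand-in for the LLM: split the query on ' and '."""
    parts = [part.strip(" ?") for part in query.split(" and ")]
    return [part for part in parts if part]


def retrieve(sub_query: str, corpus: list[str], k: int = 1) -> list[str]:
    """Toy retriever: rank documents by the number of words shared with the query."""
    query_words = set(sub_query.lower().split())
    ranked = sorted(
        corpus,
        key=lambda doc: len(query_words & set(doc.lower().split())),
        reverse=True,
    )
    return ranked[:k]


def retrieve_decomposed(query: str, corpus: list[str], k: int = 1) -> list[str]:
    """Retrieve for each sub-query and merge the results, keeping first occurrences."""
    context: list[str] = []
    for sub_query in decompose(query):
        for doc in retrieve(sub_query, corpus, k):
            if doc not in context:
                context.append(doc)
    return context


corpus = [
    "the time is shown on the clock",
    "the weather today is sunny",
    "cats sleep a lot",
]
context = retrieve_decomposed("give me the time and weather", corpus)
print(context)
```

Each sub-query gets its own retrieval budget `k`, so one dominant topic cannot crowd the other out of the context.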

We first instantiate the prompt

In [None]:
decomposition_prompt = ChatPromptTemplate.from_template(QA_DECOMPOSITION_PROMPT)

We now build our chain

In [None]:
def _split(input_text: str) -> list[str]:
    """Split the LLM output into sub-queries on periods, dropping empty chunks."""
    return [chunk.strip() for chunk in input_text.split(".") if chunk.strip()]


query_decomposition_chain = decomposition_prompt | llm | StrOutputParser() | RunnableLambda(_split)

We try various examples

In [None]:
query_decomposition_chain.invoke("what is the time ?")

In [None]:
query_decomposition_chain.invoke("give me the time and weather")

We can see that this query decomposition works well for these basic examples

## <a id='toc4_3_'></a>[HyDE rewriting](#toc0_)

In a different vein, HyDE rewriting is a common advanced technique. The idea behind HyDE is simple: generate a fake answer to the user query without a RAG call (so it will often contain false information), then embed this answer and use it in the retrieval process, thus performing answer-answer matching instead of question-answer matching. The underlying assumption is that this matching is more precise, and thus that the retrieved chunks are more relevant.
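
The difference between question-answer and answer-answer matching can be illustrated with a deliberately contrived toy. The assumptions are significant: `embed` is a bag-of-words count in place of a real embedding model, `fake_answer` is a hard-coded stand-in for the LLM, and the corpus was chosen so that the effect is visible.

```python
import math
from collections import Counter


def embed(text: str) -> Counter:
    """Toy bag-of-words vector (an assumption standing in for a real embedding model)."""
    tokens = (word.strip(".,?!").lower() for word in text.split())
    return Counter(token for token in tokens if token)


def cosine(a: Counter, b: Counter) -> float:
    """Cosine similarity between two sparse count vectors."""
    dot = sum(count * b[token] for token, count in a.items())
    norm = math.sqrt(sum(c * c for c in a.values())) * math.sqrt(sum(c * c for c in b.values()))
    return dot / norm if norm else 0.0


def search(text: str, corpus: list[str]) -> str:
    """Return the corpus entry whose vector is closest to the vector of `text`."""
    vector = embed(text)
    return max(corpus, key=lambda doc: cosine(vector, embed(doc)))


def fake_answer(query: str) -> str:
    """Hypothetical stand-in for the LLM that drafts an answer without retrieval."""
    return "Data science in finance is used for fraud detection, credit scoring and risk modelling."


corpus = [
    "Banks apply machine learning to fraud detection and credit scoring.",
    "The primary colours are red, yellow and blue.",
    "What are good questions to ask in an interview?",
]
query = "What are the primary applications of data science in the field of finance?"

direct_hit = search(query, corpus)  # question-answer matching
hyde_hit = search(fake_answer(query), corpus)  # answer-answer matching
print(direct_hit)
print(hyde_hit)
```

In this toy corpus the raw question shares only surface words with an irrelevant document, while the drafted answer shares domain vocabulary with the relevant one. Real embedding models are far less sensitive to surface overlap, so the gain in practice has to be measured.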

As a commonly used tool, HyDE is available out of the box in LangChain. We will thus use the available tools, but the technique itself is easy to replicate with a well-crafted prompt. LangChain ships several standard prompts selected by a prompt key; the cell below passes the general-purpose `web_search` key, while the financial QA prompt (`fiqa`) is the one aimed at finance-related queries.

In [None]:
hyde_pipeline = HypotheticalDocumentEmbedder.from_llm(llm, embeddings, "web_search")

In [None]:
print(hyde_pipeline.llm_chain.invoke("What are the primary applications of data science in the field of finance?"))

We can see the generated answer above, which looks like a chunk that could be retrieved. The same HypotheticalDocumentEmbedder object can generate and embed this hypothetical answer in a single call, using the models defined above, and the resulting vector then feeds the rest of the retrieval pipeline. We will show this in a very basic way:

In [None]:
chroma_vector_store.similarity_search_by_vector(
    hyde_pipeline.embed_query("What are the primary applications of data science in the field of finance?"),
    k=5,
)

## <a id='toc4_4_'></a>[Step Back prompting](#toc0_)

The last technique we will see is Step Back prompting, which works as a mix between sub-query decomposition and rewriting. The idea is to have an LLM mimic human behavior by "pausing and reflecting" before answering a question, looking for higher-level concepts or principles to guide the thought process. The goal is to extract an additional "step back" query and perform retrieval on both it and the original query, combining the results.

This can be simply implemented using an LLM call, parsing the output. We will consider the original query and step back question as subqueries in a similar fashion to the query decomposition pipeline.
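
As a rough sketch of how the two queries can be used together, the toy below retrieves for both and merges the results. `step_back` is a hard-coded, hypothetical stand-in for the LLM call, and `retrieve` is a word-overlap ranking assumed here in place of a vector store.

```python
def step_back(query: str) -> str:
    """Hypothetical stand-in for the LLM: return a broader version of the question."""
    return "What is data science used for in industry?"


def retrieve(query: str, corpus: list[str], k: int = 2) -> list[str]:
    """Toy retriever: rank documents by the number of words shared with the query."""
    query_words = {word.strip("?.,!").lower() for word in query.split()}
    ranked = sorted(
        corpus,
        key=lambda doc: len(query_words & set(doc.lower().split())),
        reverse=True,
    )
    return ranked[:k]


def step_back_retrieve(query: str, corpus: list[str], k: int = 2) -> list[str]:
    """Retrieve for the original and the step-back query, then drop duplicates."""
    results: list[str] = []
    for sub_query in (query, step_back(query)):
        results.extend(retrieve(sub_query, corpus, k))
    # dict.fromkeys keeps the first occurrence of each document, in order
    return list(dict.fromkeys(results))


corpus = [
    "data science in finance supports fraud detection",
    "data science is used across industry for forecasting",
    "the weather is sunny",
]
merged = step_back_retrieve("applications of data science in finance", corpus, k=2)
print(merged)
```

Placing the original query first keeps its results at the top of the merged list, with the broader step-back results acting as supporting context.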

In [None]:
step_back_pipeline = {
    "query": RunnablePassthrough(),
    "sb_query": ChatPromptTemplate.from_template(QA_STEP_BACK_PROMPT) | llm | StrOutputParser(),
} | RunnableLambda(lambda x: [x["query"], x["sb_query"]])

In [None]:
step_back_pipeline.invoke("What are the primary applications of data science in the field of finance?")

# <a id='toc5_'></a>[Combination](#toc0_)

Finally, we will try to combine all the previously seen techniques to create a thorough query optimisation chain. We will first take an input query with context, rewrite it to integrate conversation history, then break it down into sub-queries, then reformulate these queries and finally perform HyDE rewriting to retrieve chunks similar to synthetic answers.

We will reuse the chains we have defined in the previous cells.
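
Before wiring the real chains together, the order of operations can be sketched with plain callables. The function `combine` and all the lambdas below are hypothetical stubs introduced for illustration; in the notebook each stage is one of the LangChain chains defined earlier.

```python
from typing import Callable


def combine(
    query: str,
    history: list[str],
    *,
    contextualise: Callable[[str, list[str]], str],
    decompose: Callable[[str], list[str]],
    reformulate: Callable[[str], str],
    hypothesise: Callable[[str], str],
    retrieve: Callable[[str], list[str]],
) -> list[str]:
    """Run the stages in order and merge the retrieved chunks without duplicates."""
    standalone_query = contextualise(query, history)
    seen: set[str] = set()
    merged: list[str] = []
    for sub_query in decompose(standalone_query):
        hypothetical_answer = hypothesise(reformulate(sub_query))
        for chunk in retrieve(hypothetical_answer):
            if chunk not in seen:
                seen.add(chunk)
                merged.append(chunk)
    return merged


# Stub stages, purely illustrative.
result = combine(
    "its uses and its risks",
    ["What does AI mean"],
    contextualise=lambda q, h: q.replace("its", "AI's"),
    decompose=lambda q: q.split(" and "),
    reformulate=str.strip,
    hypothesise=lambda q: f"answer about {q}",
    retrieve=lambda text: [text.upper(), "shared chunk"],
)
print(result)
```

Passing the stages as arguments keeps the orchestration testable without a model or a vector store.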

We define a complex example with memory

In [None]:
memory = ConversationBufferMemory(return_messages=True, input_key="input", output_key="output")

memory.save_context(
    {"input": "What does AI mean "},
    {
        "output": "AI stands for Artificial Intelligence. It refers to computer systems or machines designed to perform tasks that typically require human intelligence, such as visual perception, speech recognition, decision-making, and language translation."
    },
)

memory.load_memory_variables({})

We build our sub-query chain, which performs reformulation, HyDE rewriting and retrieval.

In [None]:
sub_query_chain = (
    # The next step performs simple rewriting (see "Simple reformulation")
    simple_reformulation_chain
    # The next step generates the hypothetical answer (see "HyDE rewriting")
    | HypotheticalDocumentEmbedder.from_llm(llm, embeddings, "web_search").llm_chain
    # The next step is the document retrieval
    | chroma_vector_store_retriever
)

We now integrate it into the full chain.

In [None]:
combination_chain = (
    # The next step performs chat history augmentation (see "Enhanced Query - Conversation")
    history_augmentation_chain
    # The next step performs query decomposition (see "Query Decomposition")
    | query_decomposition_chain
    # The next step runs the sub-query chain for all identified sub-queries (see previous cell)
    | RunnableLambda(lambda x: [sub_query_chain.invoke(sq) for sq in x])
    # Finally we ensemble all the retrieved documents into a single de-duplicated list
    | RunnableLambda(lambda x: drop_document_duplicates([chunk for sq_context in x for chunk in sq_context]))
)

We test the full chain with a relevant query

In [None]:
query = "What are its applications to FS and how do you implement it"
combination_chain.invoke(query)