## Estimating the K hyperparameter using the Query Complexity Score

### Overview

This notebook explains how to dynamically estimate the **K hyperparameter** (the number of documents to retrieve) during the retrieval phase using the **Query Complexity Score** (QCS).

### 1. Intuition

When building a comprehensive RAG system, one of the most important hyperparameters to tune is **K**, the number of documents (before reranking, if any) to retrieve for each query. The choice of K can significantly affect the performance of the system, especially in terms of **precision** and **answer completeness**:
 - A **low K** will lead to _higher precision_ (since only the top-ranked, most likely relevant documents are kept), but may miss relevant information, which leads to incomplete answers;
  
 - A **high K** will lead to _lower precision_, but may provide more relevant information, which leads to more complete answers.

Unfortunately, **not all queries are created equal**: some queries are more complex than others and may therefore require more documents to be retrieved (higher K) to provide a complete answer, while simpler queries may require fewer documents (lower K) to achieve a satisfactory answer. For example, the query "_What's the capital of Italy?_" is a simple one and can be answered with a single chunk that contains the phrase "Rome is the capital of Italy"; on the other hand, the query "_What is the capital of Italy and what can I visit there? What are the main attractions?_" is more complex and may require multiple chunks to be retrieved in order to provide an answer.

This leads to the problem of **estimating the K hyperparameter** for each query, which is not a simple task. What if we could associate each query with a score that roughly reflects its complexity and, based on that score, choose the K value for that query, so that higher complexity scores lead to higher K values? This is the intuition behind the **Query Complexity Score** (QCS), which is discussed in this notebook.

### 2. Solutions

#### 2.1 High Static K

The simplest approach to _"solve"_ this problem is setting a high, fixed K value. If the range of complexity of the queries that our system is going to receive is known a priori, K can be fixed to a value high enough to provide a satisfactory answer even for the most complex queries.
As already discussed, this approach has the drawback of leading to lower precision as well as higher costs and latency even for the simplest queries.

> Please note that these problems could be somewhat mitigated by using a **reranker** model on the retrieved chunks, but this notebook focuses on another approach entirely.

#### 2.2 Training a Model to Estimate K for potential queries

Another approach is to train a model to predict the value of K for a given query. This model can be trained on a dataset of queries and their corresponding K values, which can be obtained by generating synthetic queries of varying complexity and manually annotating them with an appropriate value of K. This approach works well if the training dataset is highly representative of the queries that the system is going to receive. However, it requires a good dataset (which is difficult to obtain, especially with synthetic queries) and a model that generalizes well to unseen queries. Moreover, it is costly, time-consuming, and requires effort to keep the model up to date in a system that is constantly evolving.

> For a thorough implementation of this approach, you can check out this [Medium Article](https://medium.com/@sauravjoshi23/optimizing-retrieval-augmentation-with-dynamic-top-k-tuning-for-efficient-question-answering-11961503d4ae) by Saurav Joshi.

#### 2.3 Ask an LLM

Of course, the **JALM** (Just Ask a Language Model) approach is always an option and, in most cases, the best one in terms of quality: a smart-enough LLM can estimate the K value for a given query based on its complexity and, optionally, the context of the system.
Another approach relies on **Query Decomposition**: starting from the original query, the LLM is _kindly_ asked to generate a set of small, atomic queries (each optionally paired with a K value) that together cover the original query. The retrieval results for each sub-query are then merged using techniques like **Reciprocal Rank Fusion** or, alternatively, each sub-query is answered separately and the sub-answers are merged. This approach is particularly useful when the original query is too complex to be answered in a single step, but it requires a capable LLM and adds considerable complexity and latency to the system.
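To make the per-sub-query K idea concrete, here is a minimal sketch of how such an LLM reply could be consumed. It assumes (our own choice, not a standard format) that the LLM was prompted to answer with a JSON list of objects with `query` and `k` fields; `SubQuery` and `parse_subqueries` are hypothetical names, and the model call itself is left out. Since LLM output cannot be fully trusted, the sketch clamps each suggested K into a fixed range.

```python
import json
from dataclasses import dataclass


@dataclass
class SubQuery:
    text: str
    k: int


def parse_subqueries(llm_reply: str, min_k: int = 1, max_k: int = 8) -> list[SubQuery]:
    """Parse a decomposition reply of the (hypothetical) form
    [{"query": "...", "k": 2}, ...] and clamp each K into [min_k, max_k]."""
    subqueries = []
    for item in json.loads(llm_reply):
        text = str(item.get("query", "")).strip()
        if not text:
            continue  # drop empty sub-queries the model may emit
        # int() also accepts integer strings such as "3"; a missing K falls back to min_k
        k = int(item.get("k", min_k))
        subqueries.append(SubQuery(text, max(min_k, min(k, max_k))))
    return subqueries


reply = '[{"query": "What is the capital of Italy?", "k": 1}, {"query": "What are the main attractions in Rome?", "k": 12}]'
for sq in parse_subqueries(reply):
    print(sq.k, sq.text)
# 1 What is the capital of Italy?
# 8 What are the main attractions in Rome?
```

Clamping keeps a single over-eager suggestion (such as `k = 12` above) from blowing up the retrieval cost of the whole query.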


### 3. Query Complexity Score

To estimate the **Complexity Score** for a given query, we will combine several heuristics:
- **Length of the query**: The longer the query (in terms of _tokens_), the more complex it is likely to be;
- **Number of different entities in the query**: The more entities are mentioned in the query, the more complex it is likely to be;
- **Number of different sentences or conjunctions in the query**: The more sentences or conjunctions are used in the query, the more complex it is likely to be.


#### 3.1 Preliminaries to QCS

To implement a function that calculates the QCS for a given query, we will take advantage of the [**spaCy** library](https://spacy.io/), which provides several powerful NLP pipelines to perform the needed operations. In particular, we will use the **en_core_web_sm** pipeline, which is blazingly fast and takes up only ~12 MB of disk space.

Here is what we'll need:

In [None]:
%conda install --yes spacy

In [None]:
import spacy

try:
    nlp = spacy.load("en_core_web_sm")
except IOError:
    from spacy.cli.download import download
    download("en_core_web_sm")
    nlp = spacy.load("en_core_web_sm")

Collecting en-core-web-sm==3.8.0
  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl (12.8 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 12.8/12.8 MB 4.2 MB/s eta 0:00:00
Installing collected packages: en-core-web-sm
Successfully installed en-core-web-sm-3.8.0
✔ Download and installation successful
You can now load the package via spacy.load('en_core_web_sm')
⚠ Restart to reload dependencies
If you are in a Jupyter or Colab notebook, you may need to restart Python in
order to load all the package's dependencies. You can do this by selecting the
'Restart kernel' or 'Restart runtime' option.


Now that we've loaded the spaCy pipeline, we can try it out on a short test text:

In [None]:
test = "Rome is the capital of Italy. Italy is known for its rich history, art (e.g. David by Michelangelo), and culture."
doc = nlp(test)

print(f"Query: {test}")
print("*"*20)
print(f"Number of tokens: {len(doc)}")
print("-"*20)
print(f"Number of sentences: {len(list(doc.sents))}")
for i, sent in enumerate(doc.sents):
    # Match whole tokens, so words such as "standard" don't count as "and"
    if any(token.lower_ == "and" for token in sent):
        print(f" Sentence {i+1} contains CONJUNCTION ('and')")
print("-"*20)
print(f"Number of entities: {len(doc.ents)}")
for i, ent in enumerate(doc.ents):
    print(f" {i+1}. Entity: {ent.text}, Label: {ent.label_}")

Query: Rome is the capital of Italy. Italy is known for its rich history, art (e.g. David by Michelangelo), and culture.
********************
Number of tokens: 26
--------------------
Number of sentences: 2
 Sentence 2 contains CONJUNCTION ('and')
--------------------
Number of entities: 5
 1. Entity: Rome, Label: GPE
 2. Entity: Italy, Label: GPE
 3. Entity: Italy, Label: GPE
 4. Entity: David, Label: PERSON
 5. Entity: Michelangelo, Label: PERSON


We get exactly what we need. Before we proceed, let's go over some observations:
- The **query length** is a good-ish indicator of query complexity, but it is not enough on its own: for example, the query "_What's the capital of the country next to France?_" isn't long, but it's ambiguous and requires some reasoning to answer. Note also that we measure length in _number of tokens_ rather than _number of characters_: since the QCS is a _weighted average_ of its components, counting tokens (words and punctuation) gives a measure that is less sensitive to word length and therefore slightly more consistent across queries;
  
- We are using the number of **distinct** entities in the query, not the total number of entities. A query such as "_Where's Italy and what's the capital of Italy?_" mentions the same entity multiple times, but the extra clause is already meant to be captured by the sentence and conjunction counts, so counting the repeated entity again would inflate the QCS;
  
- **Sentence Segmentation** is done smartly: in the example text, the fragment "_e.g. David by Michelangelo_" contains periods but is not split into separate sentences;
  
- Estimating the number of **Conjunctions** in a query is the trickiest part. We can't just count how many times the token "_and_" appears in the query. For example, the query "What are the Q3 earnings of Johnson and Johnson?" contains a conjunction, but it's part of a company name, so it shouldn't add complexity. What we ultimately want to count is the number of **Coordinating Conjunctions** (CCs) that connect main clauses in the query. We can use **Dependency Parsing** to achieve this.

> Please note that to calculate the QCS we are only considering the "and" token, but the set of coordinating conjunctions consists of the _FANBOYS_ conjunctions: **For**, **And**, **Nor**, **But**, **Or**, **Yet**, **So**. This is a matter of preference that can be adjusted based on the specific use case.
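The dependency-parsing idea can be illustrated without spaCy. The sketch below uses tiny, hand-written parses (simplified, and not actual spaCy output) stored in a hypothetical `Tok` record, and counts a coordinating conjunction only when its head is a verb that also has a verbal conjunct, i.e. when it joins two clauses rather than two nouns. The set of conjunctions to consider is a parameter, so switching from "and" alone to the full FANBOYS set is a one-argument change; in spaCy the equivalent test would compare `token.lower_` against that set.

```python
from dataclasses import dataclass


@dataclass
class Tok:
    text: str
    pos: str   # coarse part-of-speech tag, e.g. "VERB", "PROPN"
    dep: str   # dependency label, e.g. "cc", "conj", "ROOT"
    head: int  # index of the syntactic head (the root points to itself)


FANBOYS = frozenset({"for", "and", "nor", "but", "or", "yet", "so"})


def clause_conjunctions(tokens: list[Tok], conjunctions: frozenset[str] = frozenset({"and"})) -> int:
    """Count coordinating conjunctions that join two verbal clauses."""
    count = 0
    for tok in tokens:
        # dep == "cc" filters out other uses of the same words (e.g. "for" as a preposition)
        if tok.text.lower() in conjunctions and tok.dep == "cc":
            head = tokens[tok.head]
            has_verb_conjunct = any(
                t.dep == "conj" and t.pos == "VERB" and t.head == tok.head for t in tokens
            )
            if head.pos == "VERB" and has_verb_conjunct:
                count += 1
    return count


# "earnings of Johnson and Johnson": "and" hangs off a proper noun -> never counted
johnson = [
    Tok("earnings", "NOUN", "ROOT", 0),
    Tok("of", "ADP", "prep", 0),
    Tok("Johnson", "PROPN", "pobj", 1),
    Tok("and", "CCONJ", "cc", 2),
    Tok("Johnson", "PROPN", "conj", 2),
]
# "Stay in Rome or visit Florence": "or" joins two verbs -> counted only with FANBOYS
stay = [
    Tok("Stay", "VERB", "ROOT", 0),
    Tok("in", "ADP", "prep", 0),
    Tok("Rome", "PROPN", "pobj", 1),
    Tok("or", "CCONJ", "cc", 0),
    Tok("visit", "VERB", "conj", 0),
    Tok("Florence", "PROPN", "dobj", 4),
]
print(clause_conjunctions(johnson))         # 0
print(clause_conjunctions(stay))            # 0 ("or" is not in the default set)
print(clause_conjunctions(stay, FANBOYS))   # 1
```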

#### 3.2 Implementing QCS

Now that we finally have the full intuition behind the QCS, we can start implementing the main function and its components. Let's first implement the function that calculates the number of relevant CCs contained in the query:

In [None]:
from spacy.tokens import Doc

def count_ccs(doc : Doc) -> float:
    """
    Count the number of relevant Coordinating Conjunctions (CCs) in the phrase.
    """
    cc_count = 0.0
    for token in doc:
        if token.text.lower() == "and" and token.dep_ == "cc":
            head = token.head

            # CASE 1. Check if 'and' connects two verbal phrases
            # head verb could be AUX, so we check for root as well
            if head.pos_ == "VERB" or head.dep_ == "ROOT":
                if any(child.dep_ == "conj" and child.pos_ == "VERB"\
                for child in head.children):
                    cc_count += 1
            
            # CASE 2. Check if 'and' has a question as a conjunct
            # We check for "wh-" words
            elif any(child.dep_ == "conj" and \
            any(t.tag_ == "WRB" or t.tag_ == "WP" for t in child.subtree) \
            for child in head.children):
                cc_count += 1

    return cc_count

The `count_ccs` function looks for two main patterns to identify relevant CCs:

1. If the head of the "and" token is a verb or the root of the sentence, it checks whether the head has any **verbal conjuncts** (`conj` children tagged as verbs). These typically appear in the second clause of the query (after the "and");
2. If the head is neither a verb nor the root, it checks whether any conjunct of the head contains a **wh-word** (tag `WRB` or `WP`, e.g. "how", "what") in its subtree. A wh-word there usually indicates that the second part of the query is a separate question.

Let's test the `count_ccs` function on some example queries:

In [79]:
def test_count_ccs():
    # This should be 1, as the second clause is a separate question
    assert count_ccs(nlp("What is the most important dish in Italy and how is it prepared traditionally?")) == 1.0
    # This should be 0, as the second clause doesn't contribute to the question
    assert count_ccs(nlp("I would like to visit Rome and Italy and I don't know which one to choose?")) == 0.0
    # This should be 1: the first "and" only connects two nouns (a company name),
    # while the second one introduces a separate question
    assert count_ccs(nlp("What are the Q3 earnings of Johnson and Johnson and how do they compare to the previous quarter?")) == 1.0

test_count_ccs()
print("✅ All tests passed! :)")

✅ All tests passed! :)


The `count_ccs` function passes all three tests! Before implementing the main function, let's define some **Normalization Constants**. They are needed because the raw counts (tokens, conjunctions, sentences, entities) are unbounded, while we want each component, and therefore the final score, to lie between 0 and 1. These constants can be tuned based on the expected query statistics; for this example we will use some reasonable values:

In [106]:
MAX_LEN = 50   # max token length for a query
MAX_CC = 2      # max relevant conjunctions expected
MAX_SENT = 3    # max sentences expected in a query
MAX_ENT = 4     # max distinct entities expected

MIN_K = 1       # minimum value for K
MAX_K = 8       # maximum value for K

Let's now implement the main `calculate_qcs` function, which takes a `query` string and four floats (`len_w`, `cc_w`, `sent_w`, `ent_w`) representing the weight of each component in the weighted sum. Since each normalized component lies in [0, 1] and the weights must sum to 1, the function returns the Query Complexity Score as a `float` in the range [0, 1].
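Written out, the score computed below is the following weighted sum, where each raw count $n$ is divided by its normalization constant ($M_{\mathrm{len}}, M_{\mathrm{cc}}, M_{\mathrm{sent}}, M_{\mathrm{ent}}$ stand for `MAX_LEN`, `MAX_CC`, `MAX_SENT`, `MAX_ENT`) and clipped at 1:

$$
\mathrm{QCS}(q) = w_{\mathrm{len}} \min\left(\frac{n_{\mathrm{tok}}}{M_{\mathrm{len}}}, 1\right) + w_{\mathrm{cc}} \min\left(\frac{n_{\mathrm{cc}}}{M_{\mathrm{cc}}}, 1\right) + w_{\mathrm{sent}} \min\left(\frac{n_{\mathrm{sent}}}{M_{\mathrm{sent}}}, 1\right) + w_{\mathrm{ent}} \min\left(\frac{n_{\mathrm{ent}}}{M_{\mathrm{ent}}}, 1\right)
$$

with $w_{\mathrm{len}} + w_{\mathrm{cc}} + w_{\mathrm{sent}} + w_{\mathrm{ent}} = 1$. As long as the weights are also non-negative (the function below does not check this), the score is a convex combination of values in $[0, 1]$ and therefore stays in $[0, 1]$.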

In [107]:
import math

def calculate_qcs(query : str,
                  len_w : float = .3,
                  cc_w : float = .2,
                  sent_w : float = .3,
                  ent_w : float = .2) -> float:
    """
    Calculate the Query Complexity Score (QCS) for a given query.
    """
    # Use a tolerance: exact float equality can reject weights that
    # sum to 1.0 only up to rounding error
    if not math.isclose(len_w + cc_w + sent_w + ent_w, 1.0):
        raise ValueError("Weights must sum to 1.0")

    doc = nlp(query)

    # Calculate each component
    len_count = len(doc)
    cc_count = count_ccs(doc)
    sentence_count = len(list(doc.sents))
    entity_count = len({ent.text for ent in doc.ents})  # distinct entities only

    # Normalize each component
    norm_len = min(len_count / MAX_LEN, 1.0)
    norm_cc = min(cc_count / MAX_CC, 1.0)
    norm_sent = min(sentence_count / MAX_SENT, 1.0)
    norm_ent = min(entity_count / MAX_ENT, 1.0)

    # Return weighted sum
    return len_w * norm_len + \
           cc_w * norm_cc + \
           sent_w * norm_sent + \
           ent_w * norm_ent

Great! Let's now test this function with some example queries of increasing complexity:

In [112]:
print(calculate_qcs("What's the capital of Italy?"))
print(calculate_qcs("What's the capital of Italy and how big is it?"))
print(calculate_qcs("What's an important dish in Italy and how is it prepared?"))
print(calculate_qcs("I'm an exchange student and i just got here in Italy. What can you tell me about the italian culture, and what famous dish can i eat in Rome?"))

0.192
0.22199999999999998
0.32799999999999996
0.748
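As a sanity check on the arithmetic, the first and third scores can be reproduced from raw counts alone. The counts below are our own reading of spaCy's analysis, chosen because they reproduce the printed scores exactly: 7 tokens for the first query (spaCy splits "What's" into "What" and "'s") and 13 for the third, one sentence and one entity ("Italy") each, and one clause-level "and" only in the third query. `qcs_from_counts` is a hypothetical helper that repeats the weighted sum of `calculate_qcs` without the spaCy parsing.

```python
# Normalization constants and default weights, copied from the cells above
LIMITS = {"len": 50, "cc": 2, "sent": 3, "ent": 4}
WEIGHTS = {"len": 0.3, "cc": 0.2, "sent": 0.3, "ent": 0.2}


def qcs_from_counts(n_tok: int, n_cc: float, n_sent: int, n_ent: int) -> float:
    """Same clipped weighted sum as calculate_qcs, applied to hand-supplied counts."""
    counts = {"len": n_tok, "cc": n_cc, "sent": n_sent, "ent": n_ent}
    return sum(WEIGHTS[c] * min(counts[c] / LIMITS[c], 1.0) for c in counts)


# "What's the capital of Italy?" -> 0.3*7/50 + 0.2*0/2 + 0.3*1/3 + 0.2*1/4
print(round(qcs_from_counts(7, 0, 1, 1), 3))   # 0.192
# "What's an important dish in Italy and how is it prepared?" -> 13 tokens, 1 clause-level "and"
print(round(qcs_from_counts(13, 1, 1, 1), 3))  # 0.328
```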


As you can see, the `calculate_qcs` function assigns higher scores to more complex queries. Estimating the K value is now straightforward: we can use a linear function that maps the QCS to a range of K values.
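In symbols, the mapping implemented by `estimate_k` below is the following; the floor reflects the fact that Python's `int()` truncates toward zero, which equals the floor for the non-negative values produced here:

$$
K(q) = \left\lfloor K_{\min} + (K_{\max} - K_{\min}) \cdot \mathrm{QCS}(q) \right\rfloor
$$

With the defaults $K_{\min} = 1$ and $K_{\max} = 8$, the first test query ($\mathrm{QCS} = 0.192$) gets $\lfloor 1 + 7 \cdot 0.192 \rfloor = \lfloor 2.344 \rfloor = 2$.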

In [116]:
def estimate_k(query : str,
               min_k: int = MIN_K,
               max_k: int = MAX_K) -> int:
    """
    Estimate the K value based on the QCS.
    """
    return int(min_k + (max_k - min_k) * calculate_qcs(query))

Let's finally test the `estimate_k` function with the same queries from before:

In [118]:
print(estimate_k("What's the capital of Italy?"))
print(estimate_k("What's the capital of Italy and how big is it?"))
print(estimate_k("What's an important dish in Italy and how is it prepared?"))
print(estimate_k("I'm an exchange student and i just got here in Italy. What can you tell me about the italian culture, and what famous dish can i eat in Rome?"))

2
2
3
6


We can observe that the `estimate_k` function assigns higher K values to more complex queries. This is exactly what we wanted to achieve!
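One consequence of the truncation in `estimate_k` is that `MAX_K` is only reached when the QCS is exactly 1.0, so the top of the K range is rarely used. The sketch below compares the original mapping with a round-half-up variant; the variant is our own suggestion, not part of the method above, and both helpers (`k_truncate`, `k_round`) are hypothetical names that take a precomputed score instead of a query.

```python
import math


def k_truncate(qcs: float, min_k: int = 1, max_k: int = 8) -> int:
    """Mapping used by estimate_k: int() truncates (floor for non-negative values)."""
    return int(min_k + (max_k - min_k) * qcs)


def k_round(qcs: float, min_k: int = 1, max_k: int = 8) -> int:
    """Alternative: round half up to the nearest K (Python's round() would round half to even)."""
    return math.floor(min_k + (max_k - min_k) * qcs + 0.5)


scores = [0.192, 0.222, 0.328, 0.748]  # QCS values printed earlier
print([k_truncate(s) for s in scores])  # [2, 2, 3, 6]
print([k_round(s) for s in scores])     # [2, 3, 3, 6]
print(k_truncate(0.95), k_round(0.95))  # 7 8
```

With rounding, K = `MAX_K` is reached for any QCS at or above 6.5/7 ≈ 0.93 instead of only at exactly 1.0; which behaviour is preferable depends on how expensive a large K is for your system.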

### 4. Conclusion

Overall, the **Query Complexity Score** is a useful, extremely fast and lightweight way to dynamically estimate the complexity of a query and, subsequently, the K value to use for retrieval. The main advantage of this approach is that it doesn't require any training data or fine-tuning, and it can be easily adapted to different contexts by adjusting the normalization constants and/or the weights of the different components.

To get the most out of this technique, it is recommended to use it in conjunction with other techniques such as **Query Decomposition** by following this approach:
1. Decompose the original query into smaller, manageable sub-queries using an LLM;
2. Estimate the QCS and K value for each sub-query;
3. Retrieve the documents for each sub-query using the estimated K value;
4. Merge the retrieved documents using techniques such as **Reciprocal Rank Fusion** or answer each sub-query separately and merge the sub-results.
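The four steps above can be sketched as follows, assuming step 1 has already produced a list of sub-query strings. The K estimator and the retriever are passed in as plain functions (for instance, `estimate_k` from this notebook and whatever vector-store search the system uses); the toy retriever and toy K rule in the usage part are hypothetical stand-ins. The merge step implements standard Reciprocal Rank Fusion, scoring each document as $\sum_i 1/(c + r_i)$ over the lists it appears in, with ranks $r_i$ starting at 1 and the commonly used constant $c = 60$.

```python
from collections import defaultdict
from typing import Callable


def rrf_fuse(rankings: list[list[str]], c: int = 60) -> list[str]:
    """Reciprocal Rank Fusion: sum 1 / (c + rank) over every list a document appears in."""
    scores: dict[str, float] = defaultdict(float)
    for ranking in rankings:
        for rank, doc_id in enumerate(ranking, start=1):
            scores[doc_id] += 1.0 / (c + rank)
    return sorted(scores, key=scores.get, reverse=True)


def retrieve_with_dynamic_k(
    sub_queries: list[str],
    k_estimator: Callable[[str], int],
    retriever: Callable[[str, int], list[str]],
) -> list[str]:
    """Steps 2-4: estimate K per sub-query, retrieve top-K ids for each, fuse with RRF."""
    rankings = [retriever(q, k_estimator(q)) for q in sub_queries]
    return rrf_fuse(rankings)


# Toy stand-ins (hypothetical): a fixed ranking per sub-query and a hard-coded K rule
TOY_INDEX = {
    "What is the capital of Italy?": ["d1", "d2", "d3"],
    "What are the main attractions in Rome?": ["d4", "d1", "d5", "d6"],
}


def toy_retriever(query: str, k: int) -> list[str]:
    return TOY_INDEX[query][:k]


def toy_k(query: str) -> int:
    return 1 if "capital" in query else 3


print(retrieve_with_dynamic_k(list(TOY_INDEX), toy_k, toy_retriever))  # ['d1', 'd4', 'd5']
```

Here the simple sub-query contributes a single document while the broader one contributes three; `d1` ranks first because both sub-queries retrieved it.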

The QCS function could also be further improved by adding more components, such as presence of **quantifiers** (e.g. "all", "some", "most") or **negations** (e.g. "not", "never"), which can also affect the complexity of the query. 
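As a minimal sketch of such an extension, the helper below counts quantifier and negation cues in an already tokenized query and normalizes them the same way as the other components. The word lists and the `max_quant`/`max_neg` constants are assumptions to be tuned; "n't" is included because spaCy's English tokenizer splits contractions such as "don't" into "do" and "n't". The two resulting values could then be added to the weighted sum with their own weights, rebalancing the existing ones so that all weights still sum to 1.

```python
# Hypothetical cue lists, starting from the examples mentioned above
QUANTIFIERS = frozenset({"all", "some", "most"})
NEGATIONS = frozenset({"not", "never", "n't"})


def cue_components(tokens: list[str], max_quant: int = 2, max_neg: int = 2) -> tuple[float, float]:
    """Return (quantifier, negation) scores, each clipped to [0, 1] like the other QCS components."""
    lowered = [t.lower() for t in tokens]
    n_quant = sum(t in QUANTIFIERS for t in lowered)
    n_neg = sum(t in NEGATIONS for t in lowered)
    return min(n_quant / max_quant, 1.0), min(n_neg / max_neg, 1.0)


# Tokens as spaCy would typically split "Which museums don't open on all holidays?"
tokens = ["Which", "museums", "do", "n't", "open", "on", "all", "holidays", "?"]
print(cue_components(tokens))  # (0.5, 0.5)
```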
