# Chunking Guidance for Retrieval-Augmented Generation
Retrieval-Augmented Generation (RAG) systems are gaining traction as a method for grounding a Large Language Model (LLM) in your own data.

Implementing a proof of concept that passes the initial “vibe” check can be achieved relatively quickly, especially when using a framework that abstracts many of the details. This approach might get an engineer 70-80% of the way there. However, as with any data-driven system, the real challenge lies in the details: extracting the extra performance that gives end users confidence in the system.

In this notebook, we’ll explore how to understand a corpus of documents, conduct some analysis to help determine the parameters of your chunking strategy, and run experiments to identify what truly works for your data. We’ll execute the baseline experiment "from first principles", i.e. where it makes sense we won't use any fancy tooling that abstracts what's going on under the hood. For subsequent experiments, we’ll utilize separate notebooks and demonstrate how some popular tools can significantly streamline the process!

The chunking strategy is just one of many factors that contribute to the success or failure of a RAG system. We’ll also touch on other areas such as retrieval, generation, and evaluation, but we won’t delve into every aspect in this notebook. Our focus will be on providing a solid foundation and understanding of the key components of a RAG system.
> KEY TAKEAWAY: This repository will help you understand the decisions that can be made when evaluating chunking strategies, and how to measure their impact.

## RAG System Components
![Question Answering Workflow](./images/Question-Answering-Workflow.jpg)

### Purpose

#### What it is

This notebook aims to guide software engineers in how they might approach evaluating and selecting a chunking methodology for a RAG system. It is not meant to be an accelerator or template; there are other great resources for that, like the [RAG Experiment Accelerator](https://github.com/microsoft/rag-experiment-accelerator). Instead, it is meant to complement those assets by highlighting the thinking and decisions that you might make whilst approaching a chunking problem. If there's one thing we'd like you to take away from this notebook, it's that every step represents a decision that impacts the performance of your application.

This notebook prepares the data ready for the experiments defined in the [experiments folder](./experiments/).

#### What it is not

This notebook does NOT intend to cover end-to-end evaluation of RAG systems, as that is a much broader topic that dives into information retrieval. Resources for tuning those elements of a RAG system can be found [here](). Nor is it intended as a template for production deployments.


## Outline

1. [Exploring your data](#exploring-your-data)
2. [Chunks, Words, Tokens and Contexts](#chunks-words-token-and-contexts)
3. [Experiment Setup](#experiment-setup)
4. [Building an Evaluation Dataset](#building-an-evaluation-dataset)

Let's get started!

## Loading a Corpus

First things first, let's load some text data to work with. Let's go with the [pubmed summarisation dataset](https://huggingface.co/datasets/ccdv/pubmed-summarization). 

> NOTE: In practice, it's unlikely that you'll receive data that is in a uniform format that's easy to work with - it may require pre-processing steps.

We'll download it from Hugging Face, but for simplicity we'll convert the dataset to pandas, which most data professionals are familiar with. This step needs to download ~3.5GB of data the first time you run it, but the data is then cached, so the load step is much faster for subsequent runs.

> NOTE: In practice, your data may be too large to load into a pandas dataframe. There are a number of best practices around this, consult your local data scientist/engineer for the best approach for your use case.

In [None]:
# Imports

from datasets import load_dataset
from uuid import uuid4
from pprint import pprint
import os
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import tiktoken as tk
import random
import json
from multiprocessing import Pool

# Set to pubmed or arxiv
publication = 'pubmed'

dataset = load_dataset('scientific_papers', publication, split='train', trust_remote_code=True)

# Convert to a pandas dataframe and do some housekeeping
ds = dataset.to_pandas()
ds['doc_id'] = [str(uuid4()) for _ in range(len(ds))]
ds['article_len'] = ds['article'].apply(lambda x: len(x.split()))

ds.drop(columns=['abstract'],inplace=True)

# Take a look at our data
display(ds.head())

## Exploring your data

One of the main things we are concerned with when building RAG systems is the length of the data that we intend to feed into an LLM. Even with context windows becoming increasingly large, we still don't really know whether simply throwing more data at an LLM is effective, and it is definitely not quick. In general, it's a good idea to have a solid understanding of the data you'll be working with and to consider non-functional requirements like performance and cost up front.

Let's start by understanding how long a typical article is, and get a feel for the boundaries and overall distribution of our data. An easy way to do this is to look at some basic stats, and draw a histogram.
> KEY TAKEAWAY: Whilst bigger models are capable, often a smaller model with better curated data will perform just as well, cost less, and result in a more responsive application.


In [None]:
display(ds.describe())

fig = plt.figure(figsize=(10, 6))
sns.histplot(ds['article_len'], bins=50, kde=True)
plt.show()

As we can see from both the table and the histogram, there is a heavy right skew, meaning that a small number of articles are extremely long.
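
As a quick sanity check alongside the plot, comparing the mean with the median gives a rough indication of skew. The sketch below uses made-up word counts rather than the real corpus, and `describe_skew` is a hypothetical helper; "mean above median" is a rule of thumb rather than a formal test.

```python
from statistics import mean, median


def describe_skew(lengths):
    """Return (mean, median, direction) for a list of document lengths."""
    avg, mid = mean(lengths), median(lengths)
    if avg > mid:
        direction = "right"  # a few very large values pull the mean up
    elif avg < mid:
        direction = "left"
    else:
        direction = "none"
    return avg, mid, direction


# Toy word counts: most articles are similar in size, one is enormous
word_counts = [2800, 3100, 2950, 3300, 3050, 112000]
avg, mid, direction = describe_skew(word_counts)
print(f"mean={avg:.0f}, median={mid:.0f}, skew={direction}")
```

A single very long article is enough to drag the mean far away from the median, which is why the median is usually the better summary of a "typical" article here.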

Deciding how to handle these unusually long articles is crucial. For some data problems you'd look to exclude outliers. But given this is an information retrieval problem, we'd like to keep as much data as possible. 

What we need to do is understand *why* some of these articles are so long, and deal with them accordingly.

First, let's zoom in on the distribution and see if we can find a more reasonable cut-off point. This will help us better manage the data and ensure the effectiveness of our information retrieval system.

> KEY TAKEAWAY: There will almost always be outliers in a dataset. It is important to understand why they are different, and how you handle them depends on your use case!

In [None]:
ds.describe(percentiles=[0.75,0.8, 0.9,0.95, 0.99])

The problem is clearly at the top end of the distribution, with a jump from 10k to 112k words in the last percentile. Let's take a closer look at the raw data.

In [None]:
# Get the indices of the articles with the longest length
longest_articles = ds['article_len'].nlargest(5).index
for idx in longest_articles:
    print(f'Article Length: {ds["article_len"][idx]}\n')
    print(ds['article'][idx]+'\n')
    print('\n')

At first glance, it appears that the two main causes of long documents are either:
- LaTeX package inclusions (mathematical formatting for scientific documents)
- Data tables (appear to be from pharmaceutical research)

In practice, we would want to spend more time understanding the drivers behind the outliers, and address as many as possible. However, for the purposes of this exercise we will focus on the LaTeX issue. Data tables could be handled by applying different PDF cracking techniques (e.g. [Azure Document Intelligence](https://learn.microsoft.com/en-us/azure/ai-services/document-intelligence/concept-retrieval-augmented-generation?view=doc-intel-4.0.0)), but that is out of scope for this notebook. Let's remove the lines that include LaTeX and take another look at the adjusted distributions.
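
The `remove_latex_packages` helper lives in the repository and isn't reproduced in this notebook. As a rough illustration of the idea only, the sketch below strips `\documentclass` and `\usepackage` commands with a regular expression. The pattern, the function name `strip_latex_preamble`, and the sample text are all assumptions, and the real helper may work quite differently (for example, by dropping whole lines).

```python
import re

# Assumed pattern: matches preamble commands such as
# \documentclass[12pt]{minimal} or \usepackage{amsmath}
LATEX_PREAMBLE = re.compile(r"\\(?:documentclass|usepackage)(?:\[[^\]]*\])?\{[^}]*\}")


def strip_latex_preamble(text):
    """Remove \\documentclass and \\usepackage commands, then collapse whitespace."""
    without_commands = LATEX_PREAMBLE.sub(" ", text)
    return " ".join(without_commands.split())


# Made-up sample resembling the inline LaTeX residue seen in long articles
sample = r"dose was \documentclass[12pt]{minimal} \usepackage{amsmath} \usepackage{wasysym} 5 mg"
print(strip_latex_preamble(sample))
```

Whatever approach you take, it's worth spot-checking a handful of cleaned articles by eye to confirm that real content isn't being removed along with the markup.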

In [None]:
from helper.general import remove_latex_packages

# Remove LaTeX package inclusions from the articles
ds['article'] = ds['article'].apply(remove_latex_packages)

# Recalculate article lengths
ds['article_len'] = ds['article'].apply(lambda x: len(x.split()))

display(ds.describe(percentiles=[0.75,0.8, 0.9,0.95, 0.99]))

This has made a difference, but 95k is still very large. For now, let's exclude the longer documents, storing them in another dataframe for analysis later.

Once we remove the odd docs, we'll check our distribution again to make sure that we now have something workable.
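
The helper used in the next cell, `remove_over_percentile`, isn't shown in this notebook. To make the idea concrete, here is a minimal standard-library sketch of splitting records at a percentile while keeping the outliers. The function name `split_at_percentile`, the nearest-rank method, and the toy records are assumptions; pandas' `quantile` interpolates by default, so its cut-off can differ slightly.

```python
import math


def split_at_percentile(records, key, pct):
    """Split records into (kept, outliers) at the pct-th percentile of `key`."""
    values = sorted(r[key] for r in records)
    # Nearest-rank percentile: the value at position ceil(pct * n), 1-indexed
    rank = max(1, math.ceil(pct * len(values)))
    cutoff = values[rank - 1]
    kept = [r for r in records if r[key] <= cutoff]
    outliers = [r for r in records if r[key] > cutoff]
    return kept, outliers


# Toy records: nine ordinary articles and one extremely long one
lengths = [100, 200, 300, 400, 500, 600, 700, 800, 900, 50000]
records = [{"doc_id": i, "article_len": n} for i, n in enumerate(lengths)]

kept, outliers = split_at_percentile(records, "article_len", 0.9)
print(len(kept), "kept;", len(outliers), "set aside for review")
```

Returning both halves, rather than only the filtered data, is what makes it easy to store the outliers for later review.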

> KEY TAKEAWAY: When removing outliers, or error records, it's important not to discard them. Store them somewhere to be reviewed at a later date.

In [None]:
from helper.general import remove_over_percentile

# Apply the helper function from helper.general
ds_99pct, ds_outliers = remove_over_percentile(ds, 'article_len', .99)

fig = plt.figure(figsize=(10, 6))

# Plot a histogram of the remaining article lengths
sns.histplot(ds_99pct['article_len'], bins=100, kde=True)
plt.title('Article Length')


# Display the plot
plt.show()

display(ds_99pct.describe())

Let's take a quick look at the smallest articles too. The above histogram indicates that there are a number of articles that have length of less than 500 words. Let's check formally and then take a peek.

In [None]:
count_tiny_articles = ds_99pct['article_len'][ds_99pct['article_len'] < 500].count()

print(f'Number of articles with less than 500 words: {count_tiny_articles}')

smallest_articles = ds_99pct['article_len'].nsmallest(count_tiny_articles).index
for idx in smallest_articles:
    print(f'Article Length: {ds_99pct["article_len"][idx]}\n')
    print(ds_99pct['article'][idx]+'\n')
    print('\n')

Scanning through the articles, there are some empty records, but the majority seem to be fragments of documents rather than complete articles. In this specific case, we can reasonably conclude that these are invalid records and exclude them from our system. It's worth calling out that this exclusion is different from excluding the long articles: these short records are malformed at source, and there is little that we can do to remediate them without calling a friend!

>NOTE: In practice, we would go back to the customer and review these records to decide whether it's appropriate to keep them or not.

In [None]:
ds_clean = ds_99pct[ds_99pct['article_len'] >= 500]

Now that we have our cleansed dataset, let's start talking about tokens, why it's good to know about them, and how they relate to designing a RAG system.

## Chunks, Words, Token and Contexts

Despite feeling like LLMs converse in our language, there are a few things that go on behind the scenes to translate our verbiage into something an algorithm understands.

Firstly, the text is `tokenized`, which means it is split into a list of `tokens`, each of which may be a whole word or a fragment of one. Think of this as loosely similar to stemming in NLP. For shorter words, the ratio of tokens to words can be 1:1 (i.e. the word = the token), but for longer or more complex words the ratio can be far higher.

These lists are then mapped into numerical vectors that the algorithm can understand. For a given corpus, we could work out the exact ratio - in fact, let's do that!

(for a visual explanation of tokens see [this](https://www.tokencounter.io/) excellent resource)

First we will apply the appropriate encoding to the text, then we simply sum up the number of tokens and words in the corpus and calculate the number of tokens / word.
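
Before running this on the real corpus, the calculation itself can be illustrated without downloading anything. The sketch below uses a deliberately crude stand-in tokenizer, `toy_encode`, which is an assumption for illustration and not how `tiktoken` splits text. In the notebook you would pass the real encoder's `encode` method instead.

```python
def token_word_ratio(texts, encode):
    """Total tokens divided by total whitespace-separated words across `texts`."""
    total_words = sum(len(t.split()) for t in texts)
    total_tokens = sum(len(encode(t)) for t in texts)
    return total_tokens / total_words


def toy_encode(text, piece_len=4):
    """Stand-in tokenizer: cut each word into pieces of at most `piece_len` characters."""
    return [
        word[i:i + piece_len]
        for word in text.split()
        for i in range(0, len(word), piece_len)
    ]


texts = ["the cat sat", "pharmacokinetics of acetaminophen"]
print(f"tokens per word: {token_word_ratio(texts, toy_encode):.2f}")
```

Even with this toy encoder, the short everyday words map to one token each while the long domain-specific words split into several, which is the effect we expect to see in a scientific corpus.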

> KEY TAKEAWAY: Tokens are the currency of LLMs - they're how we measure inputs and outputs and often what pricing and API limits are based on. 

In [None]:
# Run the tokeniser over the articles and store the results in the DataFrame.
# Given the size of the dataset this can take up to 10 minutes, so let's select a
# random sample of 5000 articles to bring the time down to ~10 seconds.
# Feel free to grab a beverage and run the analysis on the entire dataset to
# compare the results if you have the time!

subset = True

if subset:
    random.seed(42)
    random_doc_ids = random.sample(list(ds_clean['doc_id'].unique()), 5000)

    # Subset the DataFrames
    ds_clean = ds_clean[ds_clean['doc_id'].isin(random_doc_ids)]

encoding = tk.encoding_for_model('gpt-3.5-turbo')

article_tokens = ds_clean['article'].apply(encoding.encode)

# Drop a stale token column if this cell is being re-run
if 'article_tokens' in ds_clean.columns:
    ds_clean = ds_clean.drop(columns=['article_tokens'])

ds_clean = ds_clean.assign(article_tokens=article_tokens)
ds_clean['article_tk_len'] = ds_clean['article_tokens'].apply(lambda x: len(x))
ds_clean['token_ratio'] = ds_clean['article_tk_len']/ds_clean['article_len']

total_words = ds_clean['article_len'].sum()
total_tokens = ds_clean['article_tk_len'].sum()

ratio = total_tokens/total_words

print(f'The ratio of tokens to words is {ratio:.2f}')
print(f"The largest article has {ds_clean['article_tk_len'].max()} tokens")
print(f"The smallest article has {ds_clean['article_tk_len'].min()} tokens")
print(f"The mean number of tokens is : {ds_clean['article_tk_len'].mean()}")

Now we know the overall ratio we can make some assumptions around cost and how much information we feed the generation step as context.

Before we move on, let's take a look at the distribution of ratios - some articles might use more complex language than others. Whilst this ratio is not directly useful, an unusual value could be a red flag that some data is being excluded. You may want to do some outlier analysis to dive deeper into how your encoder is functioning, check that it is not excluding words that are specific to the domain, and confirm whether or not your embedding model is appropriate for the use case. We need to make sure this step is correct, as it's our last touch point before the text reaches the LLM.
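
One simple way to start that outlier analysis is the interquartile range (IQR) rule. The sketch below is only one of many possible approaches; the function name `iqr_outliers`, the conventional multiplier of 1.5, and the toy ratios are assumptions for illustration.

```python
from statistics import quantiles


def iqr_outliers(ratios, k=1.5):
    """Return ids whose ratio lies more than k * IQR outside the quartiles."""
    q1, _, q3 = quantiles(ratios.values(), n=4)
    iqr = q3 - q1
    lower, upper = q1 - k * iqr, q3 + k * iqr
    return [doc_id for doc_id, ratio in ratios.items() if ratio < lower or ratio > upper]


# Toy token-per-word ratios keyed by document id
token_ratios = {
    "a": 1.30, "b": 1.35, "c": 1.32, "d": 1.38,
    "e": 1.33, "f": 1.36, "g": 1.31, "h": 2.90,
}
print(iqr_outliers(token_ratios))
```

Any documents flagged this way are good candidates for reading by hand to see what the tokenizer is struggling with.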

For this we'll use a violin plot which is similar to a box plot, but with more flexibility.

The plot below shows that we have a few articles that have high token ratios - an interesting problem for another time!

> KEY TAKEAWAY: Derived stats like token ratios can be useful in spotting outliers in your data, or areas where your encoding model is not functioning as you expect.

In [None]:
# create a violin plot of token ratios
fig = plt.figure(figsize=(10, 6))
sns.violinplot(x='token_ratio', data=ds_clean)
plt.title('Token Ratio')
plt.show()

### Why do we care about this?

How many records do we want to include in our Augmentation step when constructing the generation prompt? Say we're using GPT-35-Turbo: we have approximately 4000 tokens to play with. This covers both **input and output**.

Let's assume the following:

1. We have a prompt template which is a total of 500 tokens, including our guardrails, instructions and any other boiler plate commentary that needs to be input to the generation step.
2. We allow for up to 500 tokens in a response. 

$$\text{Guardrails} + \text{Instructions} + \text{Estimated Response Length} + \text{Retrieved Context} < \text{Model Token Limit}$$


Meaning with **GPT-35-Turbo**:
$$500 + 500 + \text{Retrieved Context} < 4000$$

or $\text{Retrieved Context} < 3000$

This leaves us with 3000 tokens to play with. If we assume a chunk size  of 500, that gives us 6 records in our retrieval step. In fact, this might be a good starting point. It glosses over whether or not 500 tokens is enough to capture something that is semantically relevant and (more importantly!) useful for the generation step, but we'll take a look at that as we experiment with different chunking strategies.
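
The same arithmetic can be wrapped in a small function so that it is easy to re-run for other models or chunk sizes. This is a minimal sketch using the numbers assumed above; `max_chunks` and its parameter names are hypothetical.

```python
def max_chunks(model_limit, prompt_tokens, response_tokens, chunk_size):
    """Number of whole chunks that fit in the context left after prompt and response."""
    context_budget = model_limit - prompt_tokens - response_tokens
    if context_budget <= 0:
        return 0  # the template and response alone already fill the window
    return context_budget // chunk_size


# Numbers from the worked example: 4000-token limit, 500-token prompt
# template, 500 tokens reserved for the response, 500-token chunks
print(max_chunks(model_limit=4000, prompt_tokens=500, response_tokens=500, chunk_size=500))
```

Halving the chunk size to 250 tokens would double the number of retrieved records to 12, which is exactly the kind of trade-off the experiments are designed to explore.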

Now we have some idea of a driving non-functional constraint. Let's talk about what makes a good chunk.

> KEY TAKEAWAY: By understanding token counts, we can estimate and budget for what we're going to send to an LLM in the generation step.

### What makes a good chunk?

Let's start by addressing what a chunk is intended to do. Its sole purpose is to provide context that enhances the "knowledge" available to the LLM, so that it can better answer a given question.

To do that well, a chunk should:

1. Be relevant to the query or task
2. Be factually accurate and avoid misleading information
3. Be specific - focused information rather than overly broad content
4. Be concise and information dense - convey information effectively within a limited space
5. Be referenceable - a user should always be able to go back and see the chunk in the context of its parent document

Let's talk about each in turn:

#### Relevance

This is primarily the focus of optimising search results: which metrics do you use to select your chunks when executing the retrieval step? [Cosine Similarity]() and the [dot product]() are two commonly used metrics. You may expand this by performing a hybrid search (adding metadata to your query - for example published dates, authors etc.), or even looking at strategies for [re-ranking](). The specifics of these techniques are beyond the scope of this notebook as they fall into a separate, and expansive, discipline (maybe another notebook!). In our example we are simply taking the top n records based on cosine similarity.
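
To make "top n by cosine similarity" concrete, here is a minimal pure-Python sketch. In a real system the vector index performs this step; the function names and the tiny two-dimensional vectors are assumptions for illustration, and the code assumes no vector is all zeros.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length, non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def top_n(query_vec, chunk_vecs, n):
    """Return the ids of the n chunks most similar to the query, best first."""
    scored = [
        (cosine_similarity(query_vec, vec), chunk_id)
        for chunk_id, vec in chunk_vecs.items()
    ]
    scored.sort(reverse=True)
    return [chunk_id for _, chunk_id in scored[:n]]


# Toy embeddings keyed by chunk id
chunk_vecs = {"a": [1.0, 0.0], "b": [0.7, 0.7], "c": [0.0, 1.0]}
print(top_n([1.0, 0.1], chunk_vecs, n=2))
```

Note that cosine similarity ignores vector length, whereas the dot product does not; the two produce the same ranking only when the embeddings are normalised.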

#### Factual Accuracy

This is fundamentally a data governance problem. Where did your data come from? Is it up to date? Can you trust the source? A recent episode of [Practical AI]() gave the example of an HR bot being queried on annual leave policies. One data source could be the official HR policy, which contains relevant, up-to-date information, whereas another source could be chat logs that are dated and include discussions from other geographies that aren't relevant to the user's context. In our example, we are essentially outsourcing the governance problem by trusting the Hugging Face dataset and PubMed as sources.

#### Specificity

This is where our chunking strategy starts to play a part. How many topics does a chunk cover? Is it general in nature, or does it cut to the point? It may be that we find that chunk length correlates well with specificity. Perhaps we need to look at topics, or semantic consistency in chunks. More on all of this to come! Increased specificity doesn't always equate to better performance. For example, what if users ask general, conversational questions? Do they always expect a scientific and targeted response? Probably not.

#### Concise

We've illustrated the finite number of tokens we have to play with in the previous section. We ideally want information dense chunks that provide the correct context without rambling and consuming too many tokens.

#### Referenceable

Whilst the ultimate answer is generated by an LLM, often users will want to check the source. It is best practice to provide both the context, and the source of the context items back to the user for verification and / or further research. We'll achieve this by making use of metadata in our vector index.
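
One way to keep chunks referenceable is to carry the parent document id and position alongside the text from the moment the chunk is created. The sketch below splits on words rather than tokens for simplicity; the `Chunk` fields and the `chunk_words` function are assumptions for illustration, not the schema used in the experiments.

```python
from dataclasses import dataclass


@dataclass
class Chunk:
    doc_id: str        # id of the parent document
    chunk_index: int   # position of this chunk within the document
    start_word: int    # offset of the first word in the parent document
    text: str


def chunk_words(doc_id, text, size, overlap=0):
    """Split text into fixed-size word windows, keeping metadata for each chunk."""
    step = size - overlap
    if step <= 0:
        raise ValueError("overlap must be smaller than size")
    words = text.split()
    chunks = []
    for index, start in enumerate(range(0, len(words), step)):
        chunks.append(Chunk(doc_id, index, start, " ".join(words[start:start + size])))
        if start + size >= len(words):
            break  # this window already reaches the end of the document
    return chunks


chunks = chunk_words("doc-1", "one two three four five six seven", size=3, overlap=1)
for chunk in chunks:
    print(chunk)
```

With `doc_id` and `start_word` stored next to each chunk, the application can link any retrieved passage back to its exact place in the source article.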

Now we understand our data, how we need to manage the inputs to a RAG system, and what a good output looks like. Let's set up our experiments.

> KEY TAKEAWAY: A solid chunking strategy will consider chunks across the following dimensions: relevance, factual accuracy, specificity, conciseness, and referenceability.

## Experiment Setup

At Microsoft, we use a framework called the [AI Garden](https://github.com/cse-labs/ai-garden/tree/main). This allows us to group related experiments into "experiment families" and collaborate across groups and engagements in sharing knowledge and learnings.

In this repo we have the [experiments](experiments) directory which contains our experiment family and distinct experiment plans and results. We also include Architectural Decision Records (ADRs) which include important technical decisions that impact the outcome, along with reasoning. I encourage you to check these out! 

> KEY TAKEAWAY: How we conduct, execute and log our experiments is important. Done well, this will allow others to understand and build on our work.

### Subset data

Given this is an illustration of the process and thinking, now that we have a clean dataset we're going to go one step further and subset the data down to 50 randomly selected articles - mostly because I'm terrified of my GPT4 bill!

To stress this point, we'll assume from here on in that the subset is the whole of the corpus.

> NOTE: In practice, to test your retrieval steps you would need to generate representative evaluation data from the entire corpus, and run retrieval against it.

In [None]:
# Select 50 random unique doc_ids and subset ds_clean - note we will use this again later!
random.seed(42)
random_doc_ids = random.sample(list(ds_clean['doc_id'].unique()), 50)

# Subset the DataFrames
ds_subset = ds_clean[ds_clean['doc_id'].isin(random_doc_ids)]

#write out the dataset to a csv file
ds_subset.to_csv('data/docs_subset.csv', index=False)

### Building an Evaluation Dataset

Now that we have the data, let's think about how we are going to evaluate our chunking strategies. The evaluation metrics are outlined in the [experiment documentation](./experiments/00-chunking-strategies-family.md). The foundation of evaluation is a good set of question / answer pairs.

Why do we do this?

#### Simulating Real-World Use:

- QA pairs mimic how users interact with a RAG system. Users ask questions, and the system retrieves documents and generates answers based on those documents.
- By testing with good QA pairs, you ensure the RAG system performs well in scenarios it's designed for.

#### Evaluating Retrieval Accuracy:

- Good QA pairs often have answers that can be found within a relevant document. This allows you to see if the RAG system retrieves the documents that actually hold the answer.
- Poor retrieval throws off the entire RAG process, so evaluating retrieval accuracy is crucial.
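
As an illustration of what a retrieval check can look like, the sketch below computes a simple hit rate: the fraction of questions whose source document appears in the top k retrieved results. This is one possible metric and not necessarily one of those defined in the experiment documentation; the function name and toy data are assumptions.

```python
def hit_rate_at_k(retrieved, relevant, k):
    """Fraction of questions whose source doc id appears in the top-k retrieved ids."""
    hits = sum(
        1
        for question, doc_id in relevant.items()
        if doc_id in retrieved.get(question, [])[:k]
    )
    return hits / len(relevant)


# Toy data: the document each question was written from, and what retrieval returned
relevant = {"q1": "doc-a", "q2": "doc-b", "q3": "doc-c"}
retrieved = {
    "q1": ["doc-a", "doc-x"],
    "q2": ["doc-y", "doc-b"],
    "q3": ["doc-x", "doc-y"],
}

print(f"hit rate@1: {hit_rate_at_k(retrieved, relevant, k=1):.2f}")
print(f"hit rate@2: {hit_rate_at_k(retrieved, relevant, k=2):.2f}")
```

Tracking a metric like this separately from answer quality helps distinguish between failures caused by retrieval and failures caused by generation.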

#### Assessing Answer Quality:

- QA pairs with well-defined answers provide a benchmark for the RAG system's generation capabilities.
- You can compare the generated answer to the actual answer in the retrieved document to see if the RAG system effectively uses the information.

#### Identifying System Biases:

- A diverse set of QA pairs helps identify potential biases in the RAG system.
- For instance, if the system struggles with specific question types or topics, the QA pairs will reveal these weaknesses.

#### Not just any QA pairs will do:

- Good QA pairs should be well-formed, grammatically correct, and cover a range of difficulty levels and topics relevant to your RAG system's intended use.
- In essence, good QA pairs provide a realistic testing ground to assess how well your RAG system retrieves information and generates accurate and relevant answers.

Now, I don't have the time or expertise to generate a set of Q&A for these pubmed articles. A popular (but flawed!) approach is to use a leading LLM to read the documents and create the question answer pairs for you. This would be reviewed and augmented by a selection of subject matter experts before being used as the basis for measurement. There are a number of problems with using this approach in the real world, as LLMs tend to generate simple "factual questions". I can't stress enough the importance of having human subject matter experts supply, or at least review and augment, this evaluation dataset to ensure it's representative of the target user intent. We see this all over the place in LLM leaderboards using datasets like [SQuAD]().

If you're interested in how to generate a good ground truth dataset - let the team know and we'll work on writing some guidance.

To do this, we will create a prompt template that submits each article, and generates a set of 5 question answer pairs in a specific format. Once we have the template, we can submit the prompts to our model of choice. 
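
The actual template lives in `rag.data_prep.generate_qa_prompt` and isn't reproduced here. As a rough sketch of the general shape only, the example below builds a prompt that asks for JSON and parses the response defensively. The wording of the template and the names `build_qa_prompt` and `parse_qa_response` are assumptions, not the repository's implementation.

```python
import json

# Hypothetical template: asks for a fixed number of pairs in a strict JSON format
QA_PROMPT_TEMPLATE = (
    "Read the article below and write {n} question/answer pairs that can be "
    "answered from the article alone.\n"
    'Respond with only a JSON list of objects with the keys "question" and "answer".\n\n'
    "Article:\n{article}"
)


def build_qa_prompt(article, n=5):
    """Fill the template with the article text and the number of pairs wanted."""
    return QA_PROMPT_TEMPLATE.format(n=n, article=article)


def parse_qa_response(raw):
    """Parse model output into question/answer dicts, dropping malformed items."""
    try:
        items = json.loads(raw)
    except json.JSONDecodeError:
        return []
    if not isinstance(items, list):
        return []
    return [
        {"question": item["question"], "answer": item["answer"]}
        for item in items
        if isinstance(item, dict) and "question" in item and "answer" in item
    ]


# Made-up model response containing one valid pair and one malformed item
raw_response = '[{"question": "What was measured?", "answer": "Blood pressure."}, {"note": "oops"}]'
print(parse_qa_response(raw_response))
```

Parsing defensively matters because even strong models occasionally return output that doesn't match the requested structure, and one bad response shouldn't sink the whole batch.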

Using GPT4 this should take about 5 minutes.

> NOTE: Given this requires a structured output, we recommend using gpt-4 for this step. There are other ways to apply structure to outputs, but for simplicity's sake we'll use the sledgehammer.

> KEY TAKEAWAY: How you construct your evaluation dataset matters. Generating synthetic data is useful to bootstrap, but to have confidence in a production system, real user / SME feedback is required for the Q&A evaluation data.

In [None]:
from helper.openai_utils import general_prompt, create_client
from helper.general import convert_to_dict
from rag.data_prep import generate_qa_prompt

# Create an OpenAI client
oai_client = create_client()
model = os.getenv("QA_MODEL")

def _process_article(article):
    return general_prompt(oai_client, generate_qa_prompt(article), model=model)


def create_qa_pairs(client, model, articles, save_output=True, parallel=True):
    if parallel:
        # Fan the articles out across worker processes.
        # Note: _process_article uses the module-level oai_client and model.
        with Pool() as pool:
            results = pool.map(_process_article, articles)

    else:
        # Use the articles passed in, rather than the global ds_subset
        prompts = [generate_qa_prompt(article) for article in articles]
        results = [general_prompt(client, prompt, model=model) for prompt in prompts]

    # Convert the raw model responses into question/answer dicts
    qa_pairs = convert_to_dict(results)
      
    questions = [pair["question"] for pair in qa_pairs]
    answers = [pair["answer"] for pair in qa_pairs]

    # Create a DataFrame with the questions and answers
    qa_df = pd.DataFrame({"question": questions, "ground_truth": answers})

    if save_output is True:
        qa_df.to_csv('data/qa_pairs.csv', index=False)
        
    return qa_df

# if the data is already saved, load it
if os.path.exists('data/qa_pairs.csv'):
    qa_df = pd.read_csv('data/qa_pairs.csv')
else:
    qa_df = create_qa_pairs(oai_client, model, ds_subset["article"], save_output=True, parallel=True)

Good work getting this far! Let's take our outputs and start experimenting!