# Selecting an embedding model for your custom data

In a recent blog post, ["Understanding embedding models: make an informed choice for your RAG"](https://unstructured.io/blog/understanding-embedding-models-make-an-informed-choice-for-your-rag), we explored what you need to know to navigate the [Hugging Face MTEB leaderboard](https://huggingface.co/spaces/mteb/leaderboard) effortlessly and select a baseline text embedding model.

You are likely to find more than one candidate model that meets your criteria. In that case, you should evaluate the candidates on your own data. Good performance on academic benchmarks is one thing, but your custom data has its own nuances, domain-specific language, and other unique traits.

In this example we'll:
* Pick three embedding models from the MTEB leaderboard: [BAAI/bge-large-en-v1.5](https://huggingface.co/BAAI/bge-large-en-v1.5), [mukaj/fin-mpnet-base](https://huggingface.co/mukaj/fin-mpnet-base), and [Snowflake/snowflake-arctic-embed-l](https://huggingface.co/Snowflake/snowflake-arctic-embed-l)
* Generate a synthetic dataset to evaluate their retrieval performance on custom data consisting of PDF files
* Set up a preprocessing pipeline using the best embedding model for our data

## Install the required dependencies

First, let's install the libraries that we will be using: 

* `unstructured` and `unstructured-ingest` for preprocessing documents
* `python-dotenv` to load the environment variables from a `.env` file
* `chromadb` and `langchain` to set up retrievers with different embedding models
* `ollama` to prompt an LLM to generate a synthetic evaluation dataset


To use this example, you'll need to get an [Unstructured API key](https://unstructured.io/api-key-hosted). The Unstructured Serverless API comes with a 14-day trial capped at 1000 pages per day. 

In [None]:
!pip install -qU "unstructured[pdf, embed-huggingface]" unstructured-ingest python-dotenv langchain chromadb ollama

## Load the environment variables

Store your environment variables, such as `UNSTRUCTURED_API_KEY` and `UNSTRUCTURED_URL`, in a `.env` file, then load them here.

In [1]:
import os
import dotenv

dotenv.load_dotenv('.env')

True
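`load_dotenv` returns `True` when it finds the file, but that doesn't guarantee that every variable you need is actually set. As a minimal sketch (the helper names `missing_env_vars` and `require_env_vars` are hypothetical, not part of `python-dotenv`), you could fail early with a clear message instead of getting an authentication error deep inside the pipeline:

```python
import os

def missing_env_vars(names: list[str]) -> list[str]:
    """Return the names from `names` that are unset or empty in the current environment."""
    return [name for name in names if not os.getenv(name)]

def require_env_vars(names: list[str]) -> None:
    """Raise early if any required variable is missing (hypothetical helper)."""
    missing = missing_env_vars(names)
    if missing:
        raise RuntimeError(f"Missing environment variables: {', '.join(missing)}")
```

Usage: call `require_env_vars(["UNSTRUCTURED_API_KEY", "UNSTRUCTURED_URL"])` right after `dotenv.load_dotenv('.env')`.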

## Preprocess PDFs from a source location

The data we use in this example is stored in two large PDF files (feel free to substitute your own data). These are annual financial reports (Form 10-K) for the year 2023 from two large companies: Walmart Inc. and Exxon Mobil Corporation. These documents are publicly available on the companies' respective websites. We'll use them as an example of domain-specific data (financial industry).

Since we don't have real user queries to evaluate retrieval performance, the next best thing is to generate a synthetic evaluation dataset from the custom data. To do that, we first need to preprocess the PDFs. Let's start with the necessary imports:

In [218]:
from unstructured_ingest.v2.pipeline.pipeline import Pipeline
from unstructured_ingest.v2.interfaces import ProcessorConfig
from unstructured_ingest.v2.processes.connectors.local import (
    LocalIndexerConfig,
    LocalDownloaderConfig,
    LocalConnectionConfig,
    LocalUploaderConfig
)
from unstructured_ingest.v2.processes.partitioner import PartitionerConfig
from unstructured_ingest.v2.processes.chunker import ChunkerConfig
from unstructured_ingest.v2.processes.embedder import EmbedderConfig

from unstructured.staging.base import elements_from_json
from unstructured.staging.base import elements_to_dicts

import json
import ollama
import pandas as pd

To process the PDFs from a local directory, set up an Unstructured ingest pipeline with a local source connector and a local destination connector. Unstructured supports dozens of source and destination connectors, so you can easily modify this pipeline to ingest documents from an S3 bucket, Azure Blob Storage, Google Drive, or any of the other supported sources.
At this stage, let's keep the destination local. Once we're done evaluating models, we can modify and re-run the pipeline to add an embedding step and a vector store as a destination.

The Unstructured processing pipeline can be assembled from a number of configurations: 

In [None]:
Pipeline.from_configs(
    context=ProcessorConfig(
        verbose=True,
        tqdm=True,
        num_processes=20,
    ),
    indexer_config=LocalIndexerConfig(input_path="PDFS"),
    downloader_config=LocalDownloaderConfig(),
    source_connection_config=LocalConnectionConfig(),
    partitioner_config=PartitionerConfig(
        partition_by_api=True,
        api_key=os.getenv("UNSTRUCTURED_API_KEY"),
        partition_endpoint=os.getenv("UNSTRUCTURED_URL"),
        strategy="fast",
        ),
    chunker_config=ChunkerConfig(
        chunking_strategy="by_title",
        chunk_max_characters=1500,
        chunk_overlap = 150,
        ),
    uploader_config=LocalUploaderConfig(output_dir="local-ingest-output")
).run()

* `ProcessorConfig` describes general behavior such as log verbosity, number of processes, etc.
* `LocalIndexerConfig`, `LocalDownloaderConfig`, and `LocalConnectionConfig` control data ingestion from a local source; you only need to provide the path to your local directory with PDFs.
* `PartitionerConfig`: use it to supply your credentials for the Unstructured Serverless API and to customize partitioning behavior, e.g. which partitioning strategy to use, whether to exclude some types of metadata, etc. In this case, we use the `fast` strategy to partition the files, as the PDFs are not complex and contain text only.
* `ChunkerConfig`: after partitioning, we chunk the documents into meaningfully sized chunks that do not exceed the input size of any of the embedding models we'll be evaluating.
* `LocalUploaderConfig`: specify a local directory to write the processed files to.

## Create an evaluation dataset

Once we have preprocessed the documents into chunks, let's build a synthetic evaluation dataset. To load all the processed files from the output directory, we can use the `elements_from_json` function for each JSON file:

In [2]:
def load_processed_files(directory_path:str):
    """
    Reads all preprocessed data from JSON files in the given directory and returns elements as a list

    Args:
        directory_path (str): The path to the directory containing JSON files.
    """
    elements = []
    for filename in os.listdir(directory_path):
        if filename.endswith('.json'):
            file_path = os.path.join(directory_path, filename)
            try:
                elements.extend(elements_to_dicts(elements_from_json(filename=file_path))) 
            except IOError:
                print(f"Error: Could not read file {filename}.")
    
    return elements

In [3]:
elements = load_processed_files("local-ingest-output")

len(elements)

1082
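Since the chunker caps chunks at 1,500 characters so that they fit the embedding models' input limits, a quick check on the loaded elements can catch surprises. The sketch below is an approximation under stated assumptions: it estimates tokens as characters divided by about 4 (a common rule of thumb for English text, not a real tokenizer) and defaults to a 512-token limit, which is the maximum sequence length of `BAAI/bge-large-en-v1.5`; check the other candidates' model cards for their own limits. `chunk_length_report` is a hypothetical helper.

```python
def chunk_length_report(elements: list[dict], max_tokens: int = 512, chars_per_token: float = 4.0) -> dict:
    """
    Summarize chunk lengths and flag chunks whose *estimated* token count exceeds max_tokens.

    The estimate (characters / chars_per_token) is a heuristic; use the model's
    own tokenizer if you need an exact count.
    """
    lengths = [len(el.get("text") or "") for el in elements]
    max_chars = max_tokens * chars_per_token
    over_limit = [el.get("element_id") for el, n in zip(elements, lengths) if n > max_chars]
    return {
        "n_chunks": len(lengths),
        "max_chars": max(lengths, default=0),
        "mean_chars": sum(lengths) / len(lengths) if lengths else 0.0,
        "estimated_over_limit": over_limit,
    }
```

Usage: `chunk_length_report(elements)` returns the chunk count, the longest and mean chunk length in characters, and the ids of any chunks that probably exceed the limit.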

Let's add a helper function that parses the LLM's responses into a list of dictionaries and adds the `context` (chunk content) and the id of the chunk each question is based on (stored under `id`), so that we can later check whether the retriever returns that chunk:

In [60]:
def convert_qa_string_to_dict(input_string, chunk_id, chunk_text):
    """
    Converts a string response from an LLM to a list of question-answer-context dictionaries.
    """
    try:
        result = json.loads(input_string)
        questions = result["questions"]
        for question in questions:
            # Record which chunk the question was generated from
            question['id'] = chunk_id
            question['context'] = chunk_text
        return questions
    except json.JSONDecodeError as e:
        print(f"Error parsing JSON: {e}")
        return []
    except (KeyError, TypeError) as e:
        # Valid JSON, but not in the expected {"questions": [...]} shape
        print(f"Unexpected response structure: {e!r}")
        return []
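Small local models sometimes wrap the JSON in a Markdown code fence or add a sentence before it, which makes `json.loads` fail on the whole response. One option, sketched here as an assumption rather than part of the original workflow, is to cut out the span between the first `{` and the last `}` before parsing. `extract_json_object` is a hypothetical helper you could apply to the raw response first.

```python
import json
from typing import Optional

def extract_json_object(text: str) -> Optional[dict]:
    """
    Return the JSON object found between the first '{' and the last '}' in `text`,
    or None if there is no such span or it does not parse as a JSON object.
    """
    start = text.find("{")
    end = text.rfind("}")
    if start == -1 or end <= start:
        return None
    try:
        parsed = json.loads(text[start:end + 1])
    except json.JSONDecodeError:
        return None
    return parsed if isinstance(parsed, dict) else None
```

This only strips surrounding noise; a response with genuinely invalid JSON (for example a trailing comma) still returns `None`, so such chunks are skipped rather than guessed at.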


For the synthetic evaluation dataset, we'll go over the chunks, and for each chunk we'll prompt `llama3.1:8b`, running locally with Ollama, to generate two question/answer pairs.

In [61]:
def generate_chunk_qa_pairs(element):
    """
    Prompts a local LLM to generate question-answer pairs for a document chunk
    """
    
    prompt = """
    You are an assistant specialized in RAG tasks. \n
    The task is the following: given a document chunk, you will have to
    generate questions that can be asked by a user to retrieve information from
    a large documentary corpus. \n
    The question should be relevant to the chunk, and should not be too specific
    or too general. The question should be about the subject of the chunk, and
    the answer needs to be found in the chunk. \n

    Remember that the question is asked by a user to get some information from a
    large documentary corpus. \n

    Generate a question that could be asked by a user without knowing the existence and the content of the corpus. \n
    Also generate the answer to the question, which should be found in the
    document chunk.  \n
    Generate TWO pairs of questions and answers per chunk in a
    dictionary with the following format, your answer should ONLY contain this dictionary, NOTHING ELSE: \n
    {
        "questions": [
            {
                "question": "XXXXXX",
                "answer": "YYYYYY"
            },
            {
                "question": "XXXXXX",
                "answer": "YYYYYY"
            }
        ]
    }
    where XXXXXX is the question and YYYYYY is the corresponding answer, which can be as long as needed. \n
    Note: If there are no questions to ask about the chunk, return an empty list.
    Focus on making relevant questions concerning the page. \n
    Here is the chunk: \n
"""
    
    messages = [
        {
            'role': 'user',
            'content': prompt + element['text']
        }
    ] 
    response = ollama.chat(model='llama3.1:8b', messages=messages)
    
    return convert_qa_string_to_dict(response['message']['content'], element['element_id'], element['text'])    
    

In [67]:
def generate_qa_pairs_dataset(elements):
    dataset = []
    for el in elements: 
        dataset.extend(generate_chunk_qa_pairs(el))
    return dataset

Finally, let's generate the dataset.
Running the following cell can take a long time, depending on your hardware, the model you use, and how large your documents are. You may also see a few JSON parsing errors; that's expected and means that some LLM responses were not valid JSON. In our experiments, there were only a negligible number of them.
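Because generation can take a long time, it may help to try the prompt on a small random sample first, and to write results to disk as you go so that an interruption doesn't lose everything. A minimal sketch, assuming you pass `generate_chunk_qa_pairs` in as `generate_fn`; the helper names and the JSONL checkpoint file are hypothetical:

```python
import json
import random

def sample_elements(elements: list[dict], n: int, seed: int = 42) -> list[dict]:
    """Return up to n randomly chosen elements (reproducible for a given seed)."""
    rng = random.Random(seed)
    return rng.sample(elements, min(n, len(elements)))

def generate_with_checkpoint(elements: list[dict], generate_fn, checkpoint_path: str) -> list[dict]:
    """
    Call generate_fn on each element and append its QA pairs to a JSONL file
    right away, so partial progress survives an interruption.
    """
    dataset = []
    with open(checkpoint_path, "a", encoding="utf-8") as f:
        for el in elements:
            pairs = generate_fn(el)
            for pair in pairs:
                f.write(json.dumps(pair) + "\n")
            f.flush()  # make sure each chunk's pairs reach the file before moving on
            dataset.extend(pairs)
    return dataset
```

For example, `generate_with_checkpoint(sample_elements(elements, 20), generate_chunk_qa_pairs, "qa_pairs_checkpoint.jsonl")` lets you inspect the output for 20 chunks before committing to the full run.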

In [None]:
eval_dataset = generate_qa_pairs_dataset(elements)

Let's save the dataset as a CSV file locally.

In [71]:
def save_dataset_as_csv(dict_list: list[dict], output_file: str):
    """
    Saves a list of dictionaries with QA pairs as a CSV file.
    """
    
    df = pd.DataFrame(dict_list)
    df = df[df['question'].notna()]
    df.to_csv(output_file, index=False)
    print(f"DataFrame saved to {output_file}")

save_dataset_as_csv(eval_dataset, "qa_pairs_dataset.csv")

DataFrame saved to qa_pairs_dataset.csv
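The metrics code later looks up each question's source chunk by the question text, so if the LLM produced the same question for two different chunks, only one of those chunk ids can count as correct. A quick stdlib sketch (the helper name is hypothetical) to spot such duplicates in `eval_dataset` before evaluating:

```python
from collections import defaultdict

def find_duplicate_questions(qa_pairs: list[dict]) -> dict[str, set]:
    """Return questions generated for more than one distinct chunk, mapped to those chunk ids."""
    ids_by_question = defaultdict(set)
    for pair in qa_pairs:
        question = pair.get("question")
        if question:  # skip missing or empty questions
            ids_by_question[question].add(pair.get("id"))
    return {q: ids for q, ids in ids_by_question.items() if len(ids) > 1}
```

If the result is non-empty, you could drop those questions or keep a single chunk id for each, depending on how strict you want the evaluation to be.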


## Set up retrievers and collect responses to questions

Now that we have an evaluation dataset, we can set up a retriever with each of the embedding models and retrieve results for each question in the evaluation dataset. The `evaluate_retriever.py` script does just that.

When running the script, pass `--model_name`. You can also configure the following parameters:
* `--n_documents_to_retrieve`: how many similar documents the retriever should return
* `--documents`: location of the processed documents to set up the retriever with
* `--qa_dataset`: location of the evaluation dataset

The script will save the results into a local CSV file.    
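The script itself isn't shown here. As a hedged sketch of its command-line interface only, the flags above could be parsed with `argparse` as follows. The default values and the output file name pattern (inferred from the `Retriever-BAAI-bge-large-en-v1.5-10_results.csv` style names the runs print) are assumptions, not the script's actual code.

```python
import argparse

def parse_args(argv=None) -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Evaluate a retriever built with one embedding model.")
    parser.add_argument("--model_name", required=True,
                        help="Hugging Face model id, e.g. BAAI/bge-large-en-v1.5")
    parser.add_argument("--n_documents_to_retrieve", type=int, default=10,
                        help="How many similar documents the retriever returns")
    # Assumed defaults: the output directory and CSV file created earlier in this notebook
    parser.add_argument("--documents", default="local-ingest-output",
                        help="Directory with the processed (chunked) documents")
    parser.add_argument("--qa_dataset", default="qa_pairs_dataset.csv",
                        help="Path to the evaluation dataset")
    return parser.parse_args(argv)

def results_filename(model_name: str, k: int) -> str:
    """Build a results file name by replacing '/' in the model id with '-'."""
    return f"Retriever-{model_name.replace('/', '-')}-{k}_results.csv"
```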

In [215]:
!python evaluate_retriever.py  --n_documents_to_retrieve 10 --model_name "BAAI/bge-large-en-v1.5"

DataFrame saved to Retriever-BAAI-bge-large-en-v1.5-10_results.csv


In [216]:
!python evaluate_retriever.py  --n_documents_to_retrieve 10 --model_name "mukaj/fin-mpnet-base"

DataFrame saved to Retriever-mukaj-fin-mpnet-base-10_results.csv


In [221]:
!python evaluate_retriever.py  --n_documents_to_retrieve 10 --model_name "Snowflake/snowflake-arctic-embed-l"

DataFrame saved to Retriever-Snowflake-snowflake-arctic-embed-l-10_results.csv


## Calculate the metrics and compare the results

Once you have the results from each of the retrievers, let's calculate some metrics. 
In this example, we'll use two metrics: Recall@K and MRR.

Since the evaluation dataset has one relevant chunk per question, the average Recall@K tells us how often this chunk was retrieved _at all_ among the K retrieved documents. A value of 1 would mean that we retrieved the relevant chunk for every question (regardless of its position in the list of retrieved chunks); a value of 0 would mean that the relevant chunk was never retrieved for any question.

The average MRR (mean reciprocal rank) is the mean of 1/rank, where rank is the position of the relevant chunk in the list of retrieved chunks; a question whose relevant chunk wasn't retrieved contributes 0. It tells us how close to the top the relevant chunk tends to be: an MRR of 1 means it was always the first result, and an MRR of 1/2 would result, for example, if it was always the second.
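To make the two definitions concrete, here is a toy example with four questions (the ranks are made up for illustration): the relevant chunk is found at positions 1, 2, and 4, and not at all for one question.

```python
def recall_and_mrr(ranks: list) -> tuple[float, float]:
    """
    ranks[i] is the 1-based position of the relevant chunk for question i,
    or None if it was not among the top-K retrieved chunks.
    """
    n = len(ranks)
    recall = sum(1 for r in ranks if r is not None) / n
    mrr = sum(1 / r for r in ranks if r is not None) / n  # misses contribute 0
    return recall, mrr

recall, mrr = recall_and_mrr([1, 2, None, 4])
print(recall, mrr)  # 0.75 0.4375
```

Recall@K is 3/4, and MRR is (1 + 1/2 + 0 + 1/4) / 4. Note that an MRR of 0.4375 doesn't correspond to any single position; it is an average of reciprocal ranks.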

In [224]:
import ast

def calculate_retrieval_metrics(evaluation_data: pd.DataFrame, retrieval_results: pd.DataFrame, k: int = 10) -> dict:
    # Map each question to the id of the chunk it was generated from.
    # If a question text appears more than once, the last occurrence wins.
    question_to_id = dict(zip(evaluation_data["question"], evaluation_data["id"]))
    retrieval_list = retrieval_results.to_dict('records')

    recall_at_k = []
    reciprocal_ranks = []

    for item in retrieval_list:
        correct_id = question_to_id.get(item["question"])
        # retrieved_ids is stored in the CSV as the string form of a Python list;
        # ast.literal_eval parses it without executing arbitrary code (unlike eval)
        retrieved_ids = ast.literal_eval(item["retrieved_ids"])[:k]

        if correct_id is not None and correct_id in retrieved_ids:
            rank = retrieved_ids.index(correct_id) + 1
            reciprocal_ranks.append(1 / rank)
            recall_at_k.append(1)
        else:
            reciprocal_ranks.append(0)
            recall_at_k.append(0)

    # Calculate average metrics

    # Recall@K: (number of relevant items in K results)/(total number of relevant items)
    # Since we only have 1 relevant id per question in the eval dataset, the average Recall@K
    # indicates how often this id was retrieved _at all_ in the K retrieved documents
    avg_recall_at_k = sum(recall_at_k) / len(retrieval_list)
    # MRR: how close to the top, on average, the relevant id was in the retrieved list of ids
    mrr = sum(reciprocal_ranks) / len(retrieval_list)
    metrics = {
        'Recall@K': avg_recall_at_k,
        'MRR': mrr,
    }

    return metrics

In [225]:
eval_dataset = pd.read_csv("qa_pairs_dataset.csv")

In [226]:
r1_results = pd.read_csv("Retriever-BAAI-bge-large-en-v1.5-10_results.csv")
print("Metrics for BAAI/bge-large-en-v1.5:")
calculate_retrieval_metrics(eval_dataset, r1_results)

Metrics for BAAI/bge-large-en-v1.5:


{'Recall@K': 0.8893617021276595, 'MRR': 0.6491239821381656}

In [227]:
r2_results = pd.read_csv("Retriever-Snowflake-snowflake-arctic-embed-l-10_results.csv")
print("Metrics for Snowflake/snowflake-arctic-embed-l:")
calculate_retrieval_metrics(eval_dataset, r2_results)

Metrics for Snowflake/snowflake-arctic-embed-l:


{'Recall@K': 0.35981087470449175, 'MRR': 0.2144581410184246}

In [228]:
r3_results = pd.read_csv("Retriever-mukaj-fin-mpnet-base-10_results.csv")
print("Metrics for mukaj/fin-mpnet-base:")
calculate_retrieval_metrics(eval_dataset, r3_results)

Metrics for mukaj/fin-mpnet-base:


{'Recall@K': 0.8416075650118203, 'MRR': 0.5666567225787077}

## Interpret the results and pick the best baseline model 
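The retriever using `BAAI/bge-large-en-v1.5` shows the best results:

* The relevant chunk is among the 10 retrieved chunks for 88.9% of the questions (Recall@10 ≈ 0.889).
* An MRR of about 0.65 means the relevant chunk tends to be ranked near the top of the list; as a rough reading, 1/0.65 ≈ 1.5, which suggests that when it is retrieved, it is usually in first or second place.

`mukaj/fin-mpnet-base` comes a close second (Recall@10 ≈ 0.842, MRR ≈ 0.567), while `Snowflake/snowflake-arctic-embed-l` performs much worse on this data (Recall@10 ≈ 0.360, MRR ≈ 0.214).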
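To compare the three candidates side by side, the metrics printed above can be collected and sorted. The snippet below simply hard-codes those values, rounded to three decimals, and orders the models by Recall@K, breaking ties by MRR (one reasonable ordering, not the only one).

```python
# Values copied from the metric outputs above (rounded to 3 decimals)
results = {
    "BAAI/bge-large-en-v1.5": {"Recall@K": 0.889, "MRR": 0.649},
    "mukaj/fin-mpnet-base": {"Recall@K": 0.842, "MRR": 0.567},
    "Snowflake/snowflake-arctic-embed-l": {"Recall@K": 0.360, "MRR": 0.214},
}

def rank_models(results: dict) -> list[tuple[str, float, float]]:
    """Sort models by Recall@K, breaking ties by MRR (both descending)."""
    rows = [(name, m["Recall@K"], m["MRR"]) for name, m in results.items()]
    return sorted(rows, key=lambda row: (row[1], row[2]), reverse=True)

for name, recall, mrr in rank_models(results):
    print(f"{name:<36} Recall@10={recall:.1%}  MRR={mrr:.3f}")
```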

## Complete the preprocessing pipeline with embedding and upload steps

The results of partitioning and chunking are already cached, so when we add an embedding configuration, the pipeline will pick up at the embedding step and won't re-process the documents from scratch.

In [219]:
Pipeline.from_configs(
    context=ProcessorConfig(
        verbose=True,
        tqdm=True,
        num_processes=20,
    ),
    indexer_config=LocalIndexerConfig(input_path="PDFS"),
    downloader_config=LocalDownloaderConfig(),
    source_connection_config=LocalConnectionConfig(),
    partitioner_config=PartitionerConfig(
        partition_by_api=True,
        api_key=os.getenv("UNSTRUCTURED_API_KEY"),
        partition_endpoint=os.getenv("UNSTRUCTURED_URL"),
        strategy="fast",
        ),
    chunker_config=ChunkerConfig(
        chunking_strategy="by_title",
        chunk_max_characters=1500,
        chunk_overlap = 150,
        ),
    embedder_config=EmbedderConfig(
        embedding_provider="langchain-huggingface",
        embedding_model_name="BAAI/bge-large-en-v1.5", # Adding the best embedding model
    ),
    uploader_config=LocalUploaderConfig(output_dir="embedded-outputs") # Changing the output location
).run()

2024-08-14 10:48:23,227 MainProcess INFO     Created index with configs: {"input_path": "PDFS", "recursive": false}, connection configs: {"access_config": {}}
2024-08-14 10:48:23,228 MainProcess INFO     Created download with configs: {"download_dir": null}, connection configs: {"access_config": {}}
2024-08-14 10:48:23,229 MainProcess INFO     Created partition with configs: {"strategy": "fast", "ocr_languages": null, "encoding": null, "additional_partition_args": null, "skip_infer_table_types": null, "fields_include": ["element_id", "text", "type", "metadata", "embeddings"], "flatten_metadata": false, "metadata_exclude": [], "metadata_include": [], "partition_endpoint": "https://api.unstructuredapp.io/general/v0/general", "partition_by_api": true, "api_key": "*******", "hi_res_model_name": null}
2024-08-14 10:48:23,230 MainProcess INFO     Created chunk with configs: {"chunking_strategy": "by_title", "chunking_endpoint": "https://api.unstructured.io/general/v0/general", "chunk_by_api"