<a href="https://colab.research.google.com/github/run-llama/llama_index/blob/main/docs/docs/examples/finetuning/embeddings/finetune_embedding.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Finetune Embeddings

In this notebook, we show users how to finetune their own embedding models.

We go through three main sections:
1. Preparing the data (our `generate_qa_embedding_pairs` function makes this easy)
2. Finetuning the model (using our `SentenceTransformersFinetuneEngine`)
3. Evaluating the model on a validation knowledge corpus

## Generate Corpus

First, we create the corpus of text chunks by using LlamaIndex to load some financial PDFs and then parsing and chunking them into plain-text chunks.

In [1]:
%pip install llama-index-llms-openai
%pip install llama-index-embeddings-openai
%pip install llama-index-finetuning
%pip install llama-index-readers-file
%pip install datasets
%pip install llama-index-embeddings-huggingface

In [None]:
import json

from llama_index.core import SimpleDirectoryReader
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.schema import MetadataMode


In [62]:
# [Optional] Load previously saved train/validation datasets from JSON
from llama_index.core.evaluation import EmbeddingQAFinetuneDataset
train_dataset = EmbeddingQAFinetuneDataset.from_json("qa_train_9602.json")
val_dataset = EmbeddingQAFinetuneDataset.from_json("qa_val_2670.json")
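
The rest of this notebook only relies on three mappings exposed by the dataset object: `queries`, `corpus`, and `relevant_docs`. As a minimal sketch of that shape, the snippet below uses a plain dictionary with made-up IDs and texts (not the real dataset, and not necessarily the exact on-disk JSON layout) and checks that every query points at a chunk that exists in the corpus. The helper `check_dataset` is hypothetical and is not part of LlamaIndex.

```python
# Toy stand-in for the three mappings used later in this notebook.
# All IDs and texts are invented for illustration.
toy_dataset = {
    "queries": {
        "q1": "What was the revenue in 2021?",
        "q2": "Who audits the company?",
    },
    "corpus": {
        "n1": "Revenue in 2021 grew compared with 2020.",
        "n2": "The financial statements are audited by an external firm.",
    },
    "relevant_docs": {"q1": ["n1"], "q2": ["n2"]},
}


def check_dataset(ds):
    """Return a list of consistency problems; an empty list means none found."""
    problems = []
    for query_id in ds["queries"]:
        doc_ids = ds["relevant_docs"].get(query_id, [])
        if not doc_ids:
            problems.append(f"{query_id}: no relevant docs")
        for doc_id in doc_ids:
            if doc_id not in ds["corpus"]:
                problems.append(f"{query_id}: unknown doc {doc_id}")
    return problems


print(check_dataset(toy_dataset))
```

A check along these lines can be useful before finetuning, since a query whose relevant chunk is missing from the corpus can never count as a hit during evaluation.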

## Run Embedding Finetuning

In [None]:
from llama_index.finetuning import SentenceTransformersFinetuneEngine

In [67]:
finetune_engine = SentenceTransformersFinetuneEngine(
    train_dataset,
    model_id="BAAI/bge-small-zh-v1.5",
    model_output_path="fine_tuned_model",
    val_dataset=val_dataset,
)

In [68]:
finetune_engine.finetune()

Step,Training Loss,Validation Loss,Cosine Accuracy@1,Cosine Accuracy@3,Cosine Accuracy@5,Cosine Accuracy@10,Cosine Precision@1,Cosine Precision@3,Cosine Precision@5,Cosine Precision@10,Cosine Recall@1,Cosine Recall@3,Cosine Recall@5,Cosine Recall@10,Cosine Ndcg@10,Cosine Mrr@10,Cosine Map@100
50,No log,No log,0.796629,0.907116,0.935581,0.9603,0.796629,0.302372,0.187116,0.09603,0.796629,0.907116,0.935581,0.9603,0.882044,0.856546,0.858302
100,No log,No log,0.798876,0.905993,0.935955,0.964045,0.798876,0.301998,0.187191,0.096404,0.798876,0.905993,0.935955,0.964045,0.884266,0.858364,0.859999
150,No log,No log,0.797753,0.90824,0.939326,0.965918,0.797753,0.302747,0.187865,0.096592,0.797753,0.90824,0.939326,0.965918,0.884595,0.858202,0.859827
200,No log,No log,0.791386,0.901873,0.938577,0.966292,0.791386,0.300624,0.187715,0.096629,0.791386,0.901873,0.938577,0.966292,0.88142,0.853888,0.855551
250,No log,No log,0.801124,0.910112,0.941573,0.96779,0.801124,0.303371,0.188315,0.096779,0.801124,0.910112,0.941573,0.96779,0.887329,0.861189,0.862716
300,No log,No log,0.797378,0.90824,0.937828,0.967041,0.797378,0.302747,0.187566,0.096704,0.797378,0.90824,0.937828,0.967041,0.884694,0.858003,0.859651
350,No log,No log,0.798876,0.916854,0.941573,0.969663,0.798876,0.305618,0.188315,0.096966,0.798876,0.916854,0.941573,0.969663,0.88826,0.861738,0.863231
400,No log,No log,0.806742,0.917228,0.948689,0.969663,0.806742,0.305743,0.189738,0.096966,0.806742,0.917228,0.948689,0.969663,0.891983,0.866606,0.86811
450,No log,No log,0.807491,0.920599,0.944569,0.971536,0.807491,0.306866,0.188914,0.097154,0.807491,0.920599,0.944569,0.971536,0.893152,0.867603,0.868972
500,0.063100,No log,0.805993,0.912734,0.94382,0.970787,0.805993,0.304245,0.188764,0.097079,0.805993,0.912734,0.94382,0.970787,0.891902,0.866197,0.867603


In [69]:
embed_model = finetune_engine.get_finetuned_model()

In [70]:
embed_model

HuggingFaceEmbedding(model_name='fine_tuned_model', embed_batch_size=10, callback_manager=<llama_index.core.callbacks.base.CallbackManager object at 0x7c8ac9dc6050>, num_workers=None, max_length=512, normalize=True, query_instruction=None, text_instruction=None, cache_folder=None)

## Evaluate Finetuned Model

In this section, we evaluate 3 different embedding models:
1. open source `BAAI/bge-m3`,
2. open source `BAAI/bge-small-zh-v1.5` (the base model we finetuned), and
3. our finetuned embedding model.

We consider 2 evaluation approaches:
1. a simple custom **hit rate** metric
2. using `InformationRetrievalEvaluator` from sentence_transformers

We show that finetuning on a synthetic (LLM-generated) dataset improves retrieval quality over the open-source base model.

In [None]:
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.core import VectorStoreIndex
from llama_index.core.schema import TextNode
from tqdm.notebook import tqdm
import pandas as pd

### Define eval function

**Option 1**: We use a simple **hit rate** metric for evaluation:
* for each (query, relevant_doc) pair,
* we retrieve top-k documents with the query, and
* it's a **hit** if the results contain the relevant_doc.

This approach is simple and intuitive, and it works with any embedding model: proprietary ones such as OpenAI embeddings as well as open-source and finetuned models.

In [None]:
def evaluate(
    dataset,
    embed_model,
    top_k=5,
    verbose=False,
):
    corpus = dataset.corpus
    queries = dataset.queries
    relevant_docs = dataset.relevant_docs

    nodes = [TextNode(id_=id_, text=text) for id_, text in corpus.items()]
    index = VectorStoreIndex(
        nodes, embed_model=embed_model, show_progress=True
    )
    retriever = index.as_retriever(similarity_top_k=top_k)

    eval_results = []
    for query_id, query in tqdm(queries.items()):
        retrieved_nodes = retriever.retrieve(query)
        retrieved_ids = [node.node.node_id for node in retrieved_nodes]
        expected_id = relevant_docs[query_id][0]
        is_hit = expected_id in retrieved_ids  # assume 1 relevant doc

        eval_result = {
            "is_hit": is_hit,
            "retrieved": retrieved_ids,
            "expected": expected_id,
            "query": query_id,
        }
        eval_results.append(eval_result)
    return eval_results
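
The hit-rate logic itself can be illustrated without an index or an embedding model. The following is a minimal sketch on made-up retrieval results; the function name `hit_rate` and the toy IDs are illustrative assumptions. Like `evaluate`, it assumes exactly one relevant document per query.

```python
def hit_rate(retrieved_by_query, relevant_docs):
    """Fraction of queries whose expected doc appears in the retrieved list."""
    hits = 0
    for query_id, retrieved_ids in retrieved_by_query.items():
        expected_id = relevant_docs[query_id][0]  # assume 1 relevant doc
        hits += expected_id in retrieved_ids
    return hits / len(retrieved_by_query)


# Toy data: each query has one relevant doc and two retrieved candidates.
toy_relevant = {"q1": ["n1"], "q2": ["n2"], "q3": ["n3"], "q4": ["n4"]}
toy_retrieved = {
    "q1": ["n1", "n7"],  # hit
    "q2": ["n5", "n2"],  # hit (rank does not matter for hit rate)
    "q3": ["n8", "n9"],  # miss
    "q4": ["n4", "n1"],  # hit
}

rate = hit_rate(toy_retrieved, toy_relevant)
print(rate)  # 0.75
```

Note that hit rate ignores the position of the relevant document within the top-k results; rank-aware metrics such as MRR capture that.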

**Option 2**: We use the `InformationRetrievalEvaluator` from sentence_transformers.

This provides a more comprehensive suite of metrics, but we can only run it against models that are compatible with sentence_transformers (the open-source models and our finetuned model, *not* the OpenAI embedding model).

In [None]:
from sentence_transformers.evaluation import InformationRetrievalEvaluator
from sentence_transformers import SentenceTransformer
from pathlib import Path


def evaluate_st(
    dataset,
    model_id,
    name,
):
    corpus = dataset.corpus
    queries = dataset.queries
    relevant_docs = dataset.relevant_docs

    evaluator = InformationRetrievalEvaluator(
        queries, corpus, relevant_docs, name=name
    )
    model = SentenceTransformer(model_id)
    output_path = "results/"
    Path(output_path).mkdir(exist_ok=True, parents=True)
    return evaluator(model, output_path=output_path)
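
To make the reported metric names more concrete, here is a simplified sketch of accuracy, precision, recall, and MRR at a cutoff `k`, computed on made-up rankings. These are textbook-style definitions written for illustration; the sentence_transformers implementation may differ in details. The function `ir_metrics_at_k` and the toy data are assumptions, not library API.

```python
def ir_metrics_at_k(ranked_by_query, relevant_docs, k):
    """Average accuracy, precision, recall and MRR at cutoff k over all queries."""
    totals = {"accuracy": 0.0, "precision": 0.0, "recall": 0.0, "mrr": 0.0}
    for query_id, ranked_ids in ranked_by_query.items():
        relevant = set(relevant_docs[query_id])
        top_k = ranked_ids[:k]
        num_hits = sum(doc_id in relevant for doc_id in top_k)
        totals["accuracy"] += 1.0 if num_hits > 0 else 0.0
        totals["precision"] += num_hits / k
        totals["recall"] += num_hits / len(relevant)
        # Reciprocal rank of the first relevant doc within the top k.
        for rank, doc_id in enumerate(top_k, start=1):
            if doc_id in relevant:
                totals["mrr"] += 1.0 / rank
                break
    num_queries = len(ranked_by_query)
    return {name: total / num_queries for name, total in totals.items()}


# Toy rankings with exactly one relevant doc per query.
toy_relevant = {"q1": ["n1"], "q2": ["n2"], "q3": ["n3"], "q4": ["n4"]}
toy_ranked = {
    "q1": ["n1", "n5", "n6"],  # relevant doc at rank 1
    "q2": ["n5", "n2", "n6"],  # relevant doc at rank 2
    "q3": ["n5", "n6", "n7"],  # relevant doc not retrieved
    "q4": ["n5", "n6", "n4"],  # relevant doc at rank 3
}

metrics = ir_metrics_at_k(toy_ranked, toy_relevant, k=3)
print(metrics)
```

With a single relevant document per query, accuracy@k and recall@k coincide, and precision@k equals accuracy@k divided by k. The metric tables in this notebook follow the same pattern.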

### Run Evals

### BAAI/bge-m3

In [None]:
bge_m3 = "local:BAAI/bge-m3"
bge_m3_val_results = evaluate(val_dataset, bge_m3)

In [None]:
df_bge_m3 = pd.DataFrame(bge_m3_val_results)

In [None]:
hit_rate_bge_m3 = df_bge_m3["is_hit"].mean()
hit_rate_bge_m3

In [None]:
evaluate_st(val_dataset, "BAAI/bge-m3", name="bge-m3")

### BAAI/bge-small-zh-v1.5

In [63]:
bge = "local:BAAI/bge-small-zh-v1.5"
bge_val_results = evaluate(val_dataset, bge)

In [64]:
df_bge = pd.DataFrame(bge_val_results)

In [65]:
hit_rate_bge = df_bge["is_hit"].mean()
hit_rate_bge

0.9108614232209737

In [66]:
evaluate_st(val_dataset, "BAAI/bge-small-zh-v1.5", name="bge")

{'bge_cosine_accuracy@1': 0.749438202247191,
 'bge_cosine_accuracy@3': 0.8756554307116104,
 'bge_cosine_accuracy@5': 0.9063670411985019,
 'bge_cosine_accuracy@10': 0.9374531835205993,
 'bge_cosine_precision@1': 0.749438202247191,
 'bge_cosine_precision@3': 0.2918851435705368,
 'bge_cosine_precision@5': 0.18127340823970034,
 'bge_cosine_precision@10': 0.0937453183520599,
 'bge_cosine_recall@1': 0.749438202247191,
 'bge_cosine_recall@3': 0.8756554307116104,
 'bge_cosine_recall@5': 0.9063670411985019,
 'bge_cosine_recall@10': 0.9374531835205993,
 'bge_cosine_ndcg@10': 0.8475168491500676,
 'bge_cosine_mrr@10': 0.8182826526365851,
 'bge_cosine_map@100': 0.8208434187476819}

### Finetuned

In [71]:
finetuned = "local:fine_tuned_model"
val_results_finetuned = evaluate(val_dataset, finetuned)

In [72]:
df_finetuned = pd.DataFrame(val_results_finetuned)

In [73]:
hit_rate_finetuned = df_finetuned["is_hit"].mean()
hit_rate_finetuned

0.951310861423221

In [74]:
evaluate_st(val_dataset, "fine_tuned_model", name="finetuned")

{'finetuned_cosine_accuracy@1': 0.7812734082397004,
 'finetuned_cosine_accuracy@3': 0.898876404494382,
 'finetuned_cosine_accuracy@5': 0.9325842696629213,
 'finetuned_cosine_accuracy@10': 0.9617977528089887,
 'finetuned_cosine_precision@1': 0.7812734082397004,
 'finetuned_cosine_precision@3': 0.299625468164794,
 'finetuned_cosine_precision@5': 0.18651685393258424,
 'finetuned_cosine_precision@10': 0.09617977528089885,
 'finetuned_cosine_recall@1': 0.7812734082397004,
 'finetuned_cosine_recall@3': 0.898876404494382,
 'finetuned_cosine_recall@5': 0.9325842696629213,
 'finetuned_cosine_recall@10': 0.9617977528089887,
 'finetuned_cosine_ndcg@10': 0.8747614202066474,
 'finetuned_cosine_mrr@10': 0.8465115034777956,
 'finetuned_cosine_map@100': 0.8481814591456369}

### Summary of Results

#### Hit rate

In [75]:
df_bge["model"] = "bge-small-zh-v1.5"
df_finetuned["model"] = "fine_tuned bge-small-zh-v1.5"

We can see that fine-tuning our small open-source embedding model improves its hit rate on the validation set, from about 0.911 to about 0.951.

In [76]:
df_all = pd.concat([df_bge, df_finetuned])
# Select the column explicitly so that only the hit rate is averaged per model
df_all.groupby("model")[["is_hit"]].mean()

model,is_hit
bge-small-zh-v1.5,0.910861
fine_tuned bge-small-zh-v1.5,0.951311
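
To put the difference between the two hit rates in perspective, the following sketch computes the absolute gain and the relative reduction in misses from the two values reported above. The variable names are illustrative.

```python
# Hit rates copied from the summary table above.
base_hit_rate = 0.910861
finetuned_hit_rate = 0.951311

# Absolute improvement in hit rate.
absolute_gain = finetuned_hit_rate - base_hit_rate

# Share of the base model's misses that the finetuned model recovers.
base_miss_rate = 1 - base_hit_rate
finetuned_miss_rate = 1 - finetuned_hit_rate
relative_miss_reduction = (base_miss_rate - finetuned_miss_rate) / base_miss_rate

print(f"absolute gain: {absolute_gain:.3f}")
print(f"relative reduction in misses: {relative_miss_reduction:.1%}")
```

In other words, a gain of about 4 percentage points in hit rate corresponds to roughly 45% fewer queries for which the relevant chunk is missing from the top 5 results.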


#### InformationRetrievalEvaluator

In [81]:
df_st_bge = pd.read_csv(
    "results/Information-Retrieval_evaluation_bge_results.csv"
)
df_st_finetuned = pd.read_csv(
    "results/Information-Retrieval_evaluation_finetuned_results.csv"
)

We can see that embedding finetuning improves results consistently across the whole suite of evaluation metrics.

In [80]:
# Label each result set, then stack them into one table indexed by model
df_st_bge["model"] = "bge-small-zh-v1.5"
df_st_finetuned["model"] = "fine_tuned bge-small-zh-v1.5"
df_st_all = pd.concat([df_st_bge, df_st_finetuned])
df_st_all = df_st_all.set_index("model")
df_st_all

model,epoch,steps,cosine-Accuracy@1,cosine-Accuracy@3,cosine-Accuracy@5,cosine-Accuracy@10,cosine-Precision@1,cosine-Recall@1,cosine-Precision@3,cosine-Recall@3,cosine-Precision@5,cosine-Recall@5,cosine-Precision@10,cosine-Recall@10,cosine-MRR@10,cosine-NDCG@10,cosine-MAP@100
bge-small-zh-v1.5,-1,-1,0.483,0.647,0.723,0.8,0.483,0.483,0.215667,0.647,0.1446,0.723,0.08,0.8,0.583477,0.635473,0.590388
bge-small-zh-v1.5,-1,-1,0.483,0.647,0.723,0.8,0.483,0.483,0.215667,0.647,0.1446,0.723,0.08,0.8,0.583477,0.635473,0.590388
bge-small-zh-v1.5,-1,-1,0.47381,0.631746,0.690079,0.766667,0.47381,0.47381,0.210582,0.631746,0.138016,0.690079,0.076667,0.766667,0.56651,0.614538,0.573082
bge-small-zh-v1.5,-1,-1,0.749438,0.875655,0.906367,0.937453,0.749438,0.749438,0.291885,0.875655,0.181273,0.906367,0.093745,0.937453,0.818283,0.847517,0.820843
fine_tuned bge-small-zh-v1.5,-1,-1,0.503175,0.679762,0.744048,0.803175,0.503175,0.503175,0.226587,0.679762,0.14881,0.744048,0.080317,0.803175,0.604404,0.652645,0.611003
fine_tuned bge-small-zh-v1.5,-1,-1,0.593254,0.79246,0.851587,0.914683,0.593254,0.593254,0.264153,0.79246,0.170317,0.851587,0.091468,0.914683,0.70441,0.75557,0.708655
fine_tuned bge-small-zh-v1.5,-1,-1,0.593254,0.79246,0.851587,0.914683,0.593254,0.593254,0.264153,0.79246,0.170317,0.851587,0.091468,0.914683,0.70441,0.75557,0.708655
fine_tuned bge-small-zh-v1.5,-1,-1,0.593254,0.79246,0.851587,0.914683,0.593254,0.593254,0.264153,0.79246,0.170317,0.851587,0.091468,0.914683,0.70441,0.75557,0.708655
fine_tuned bge-small-zh-v1.5,-1,-1,0.781273,0.898876,0.932584,0.961798,0.781273,0.781273,0.299625,0.898876,0.186517,0.932584,0.09618,0.961798,0.846512,0.874761,0.848181
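
The table above contains several rows per model. This is consistent with results from earlier runs having been appended to the same CSV files, although that is an assumption here. The last row for each model matches the metrics printed by `evaluate_st` earlier. Under that assumption, a minimal sketch for keeping only the most recent row of a results file looks like this (the CSV text and the helper `last_row` are made up for illustration):

```python
import csv
import io

# Toy results file with one row per evaluation run; values are invented.
toy_csv = """epoch,steps,cosine-Accuracy@1
-1,-1,0.40
-1,-1,0.45
-1,-1,0.50
"""


def last_row(csv_text):
    """Return the final data row of a CSV string as a dict of strings."""
    rows = list(csv.DictReader(io.StringIO(csv_text)))
    return rows[-1]


latest = last_row(toy_csv)
print(latest["cosine-Accuracy@1"])  # 0.50
```

Alternatively, deleting the `results/` directory before re-running the evaluation would avoid mixing rows from different runs.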


In [82]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
!zip -r results.zip results

  adding: results/ (stored 0%)
  adding: results/Information-Retrieval_evaluation_finetuned_results.csv (deflated 62%)
  adding: results/Information-Retrieval_evaluation_bge_results.csv (deflated 64%)
