# Finetune Embeddings

In this notebook, we show users how to finetune their own embedding models.

We go through two main sections:
1. Preparing the data (our `generate_qa_embedding_pairs` function makes this easy)
2. Finetuning the model (using our `SentenceTransformersFinetuneEngine`)

## Generate Corpus

First, we create the corpus of text chunks by using LlamaIndex to load a folder of PDFs (research papers, judging by their arXiv-style filenames) and parse them into plain-text chunks.

In [1]:
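# %pip install llama-index-llms-openai
# %pip install llama-index-embeddings-openai
# %pip install llama-index-finetuning
# %pip install llama-index-llms-groq
# %pip install llama-index-embeddings-huggingface
# %pip install python-dotenv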

In [2]:
import json

from llama_index.core import SimpleDirectoryReader
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.schema import MetadataMode

### Process Data

In [4]:
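import os

folder_path = "./data/"
# List the folder once and keep only PDFs, so JSON files saved next to them are never picked up.
# Note: os.listdir returns entries in arbitrary order; sort the list if you need a reproducible split.
pdf_files = [os.path.join(folder_path, f) for f in os.listdir(folder_path) if f.endswith(".pdf")]
TRAIN_FILES = pdf_files[:-1]  # every PDF except the last one
VAL_FILES = pdf_files[-1:]  # the last PDF, as a one-element list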

TRAIN_CORPUS_FPATH = "./data/train_corpus.json"
VAL_CORPUS_FPATH = "./data/val_corpus.json"

In [5]:
def load_corpus(files, verbose=False):
    if verbose:
        print(f"Loading files {files}")

    reader = SimpleDirectoryReader(input_files=files)
    docs = reader.load_data()
    if verbose:
        print(f"Loaded {len(docs)} docs")

    parser = SentenceSplitter()
    nodes = parser.get_nodes_from_documents(docs, show_progress=verbose)

    if verbose:
        print(f"Parsed {len(nodes)} nodes")

    return nodes
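`SentenceSplitter` packs whole sentences into token-bounded chunks, with some overlap between neighbouring chunks; as far as we know its defaults are `chunk_size=1024` tokens and `chunk_overlap=200`. The sketch below is a simplified, word-based sliding window (the helper `chunk_words` is hypothetical and is not SentenceSplitter's sentence-aware algorithm); it only illustrates how chunk size and overlap interact.

```python
def chunk_words(text: str, chunk_size: int = 200, overlap: int = 40) -> list:
    """Split text into windows of `chunk_size` words, each sharing `overlap`
    words with the previous window (hypothetical illustration only)."""
    if not 0 <= overlap < chunk_size:
        raise ValueError("overlap must be in [0, chunk_size)")
    words = text.split()
    step = chunk_size - overlap  # how far each new window advances
    chunks = []
    for start in range(0, len(words), step):
        chunks.append(" ".join(words[start:start + chunk_size]))
        if start + chunk_size >= len(words):  # this window reached the end
            break
    return chunks


print(chunk_words(" ".join(str(i) for i in range(10)), chunk_size=4, overlap=1))
# → ['0 1 2 3', '3 4 5 6', '6 7 8 9']
```

To change the real chunking, pass `chunk_size` and `chunk_overlap` to `SentenceSplitter(...)` inside `load_corpus`; smaller chunks yield more (and more specific) question/chunk pairs.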

We do a very naive train/val split by file: every PDF except the last one returned by `os.listdir` forms the training corpus, and the last one forms the validation corpus.
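Because `os.listdir` order is not guaranteed, the split above can change between machines. A minimal sketch of a deterministic alternative, assuming a hypothetical helper `split_pdf_files` that sorts the PDF paths before holding out the last `n_val` of them:

```python
import os


def split_pdf_files(filenames, folder, n_val=1):
    """Sorted, PDF-only train/val split by file (hypothetical helper)."""
    pdfs = sorted(
        os.path.join(folder, name)
        for name in filenames
        if name.lower().endswith(".pdf")  # skip saved JSON and other files
    )
    if len(pdfs) <= n_val:
        raise ValueError("need more PDFs than validation files")
    return pdfs[:-n_val], pdfs[-n_val:]
```

Usage in this notebook would be `TRAIN_FILES, VAL_FILES = split_pdf_files(os.listdir(folder_path), folder_path)`. The held-out file then becomes whichever name sorts last, so the loading logs would differ from the ones shown here.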

In [6]:
train_nodes = load_corpus(TRAIN_FILES, verbose=True)
val_nodes = load_corpus(VAL_FILES, verbose=True)

Loading files ['./data/2305.14314v1.pdf', './data/2405.20202v1.pdf', './data/1706.03762v5.pdf', './data/2106.09685v2.pdf']
Loaded 78 docs
Parsed 106 nodes
Loading files ['./data/2405.15731v1.pdf']
Loaded 22 docs
Parsed 39 nodes


### Generate synthetic queries

Now we use an LLM (Llama 3 8B, `llama3-8b-8192`, served through Groq) to generate questions, using each text chunk in the corpus as context.

Each pair of (generated question, text chunk used as context) becomes a datapoint in the finetuning dataset (either for training or evaluation).
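Roughly, as far as we can tell from the library, `generate_qa_embedding_pairs` prompts the LLM once per node, strips list numbering from the returned lines, keeps a fixed number of questions per chunk, and records which node each question came from. The sketch below imitates that loop with a pluggable `ask_llm` callable; `build_qa_pairs`, `ask_llm`, and the prompt wording are hypothetical, while the returned `queries` / `corpus` / `relevant_docs` keys mirror the fields of `EmbeddingQAFinetuneDataset`.

```python
import re
import uuid
from typing import Callable, Dict, List


def build_qa_pairs(
    chunks: Dict[str, str],
    ask_llm: Callable[[str], str],
    num_questions_per_chunk: int = 2,
) -> dict:
    """Turn {node_id: chunk_text} into synthetic (question, chunk) pairs."""
    queries: Dict[str, str] = {}
    relevant_docs: Dict[str, List[str]] = {}
    for node_id, text in chunks.items():
        # Hypothetical prompt; the library ships its own template.
        prompt = (
            f"Context:\n{text}\n\n"
            f"Write {num_questions_per_chunk} questions that can be answered "
            "using only the context above, one per line."
        )
        raw_lines = ask_llm(prompt).strip().split("\n")
        # Strip leading "1." / "2)" style numbering and drop empty lines.
        questions = [re.sub(r"^\d+[\).\s]", "", line).strip() for line in raw_lines]
        questions = [q for q in questions if q][:num_questions_per_chunk]
        for question in questions:
            query_id = str(uuid.uuid4())
            queries[query_id] = question
            relevant_docs[query_id] = [node_id]  # the chunk the question came from
    return {"queries": queries, "corpus": dict(chunks), "relevant_docs": relevant_docs}


# A stub "LLM" that always answers with three numbered questions.
fake_llm = lambda prompt: "1. What is X?\n2) Why Y?\n\n3. Extra?"
pairs = build_qa_pairs({"n1": "chunk a", "n2": "chunk b"}, fake_llm)
print(len(pairs["queries"]))  # two questions kept per chunk
```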

In [7]:
from llama_index.finetuning import generate_qa_embedding_pairs
from llama_index.core.evaluation import EmbeddingQAFinetuneDataset

In [8]:
from dotenv import load_dotenv

load_dotenv()

True

In [15]:
from llama_index.llms.groq import Groq

# One Groq-hosted Llama 3 8B client, reused for both splits (reads GROQ_API_KEY loaded from .env above).
llm = Groq(model="llama3-8b-8192", api_key=os.environ["GROQ_API_KEY"])

# For each node, ask the LLM for questions that the chunk answers.
train_dataset = generate_qa_embedding_pairs(llm=llm, nodes=train_nodes)
val_dataset = generate_qa_embedding_pairs(llm=llm, nodes=val_nodes)

train_dataset.save_json("train_dataset.json")
val_dataset.save_json("val_dataset.json")

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 106/106 [04:20<00:00,  2.46s/it]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 39/39 [01:52<00:00,  2.89s/it]


In [9]:
# [Optional] Reload the saved datasets instead of regenerating them
train_dataset = EmbeddingQAFinetuneDataset.from_json("train_dataset.json")
val_dataset = EmbeddingQAFinetuneDataset.from_json("val_dataset.json")
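Since the datasets come from an LLM, a quick sanity check before training can catch empty questions or dangling document IDs. A minimal sketch, assuming a hypothetical helper `check_qa_dataset` that works on the plain dictionaries exposed as `queries`, `corpus`, and `relevant_docs`:

```python
from typing import Dict, List


def check_qa_dataset(
    queries: Dict[str, str],
    corpus: Dict[str, str],
    relevant_docs: Dict[str, List[str]],
) -> List[str]:
    """Return a list of problems found; an empty list means every check passed."""
    problems = []
    for query_id, question in queries.items():
        if not question.strip():
            problems.append(f"{query_id}: empty question")
        doc_ids = relevant_docs.get(query_id, [])
        if not doc_ids:
            problems.append(f"{query_id}: no relevant docs")
        for doc_id in doc_ids:
            if doc_id not in corpus:  # question points at a chunk that is missing
                problems.append(f"{query_id}: unknown doc {doc_id}")
    return problems
```

In this notebook it would be called as `check_qa_dataset(train_dataset.queries, train_dataset.corpus, train_dataset.relevant_docs)`, and likewise for `val_dataset`.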

## Run Embedding Finetuning

In [10]:
from llama_index.finetuning import SentenceTransformersFinetuneEngine

In [None]:
finetune_engine = SentenceTransformersFinetuneEngine(
    train_dataset,
    model_id="sentence-transformers/all-MiniLM-L6-v2",
    model_output_path="minilm-finetuned",
    val_dataset=val_dataset,
)
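As far as we can tell from the LlamaIndex source, when no `loss` is passed the engine trains with sentence-transformers' `MultipleNegativesRankingLoss`: each question should score its own chunk higher than every other chunk in the same batch, which act as in-batch negatives. The pure-Python sketch below illustrates that objective on tiny vectors (the function names are hypothetical; the real loss works on batched tensors, and `scale=20.0` is the sentence-transformers default as we understand it).

```python
import math
from typing import List

Vector = List[float]


def cosine(a: Vector, b: Vector) -> float:
    """Cosine similarity between two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def multiple_negatives_ranking_loss(
    queries: List[Vector], positives: List[Vector], scale: float = 20.0
) -> float:
    """Mean cross-entropy where query i must rank positives[i] first among all
    positives in the batch; the other rows act as in-batch negatives."""
    total = 0.0
    for i, query in enumerate(queries):
        logits = [scale * cosine(query, p) for p in positives]
        # Log-sum-exp with the max subtracted for numerical stability.
        max_logit = max(logits)
        log_sum_exp = max_logit + math.log(sum(math.exp(z - max_logit) for z in logits))
        total += log_sum_exp - logits[i]
    return total / len(queries)


# Matching pairs give a near-zero loss; swapped pairs are heavily penalized.
print(multiple_negatives_ranking_loss([[1, 0], [0, 1]], [[1, 0], [0, 1]]))
print(multiple_negatives_ranking_loss([[1, 0], [0, 1]], [[0, 1], [1, 0]]))
```

One consequence of in-batch negatives is that a larger `batch_size` gives each question more negatives to be ranked against.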

In [12]:
finetune_engine.finetune()

Epoch:   0%|          | 0/2 [00:00<?, ?it/s]

Iteration:   0%|          | 0/22 [00:00<?, ?it/s]

Iteration:   0%|          | 0/22 [00:00<?, ?it/s]

In [13]:
embed_model = finetune_engine.get_finetuned_model()
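The held-out `val_dataset` can be used to check whether finetuning helped. One common metric is hit rate@k: the fraction of questions whose source chunk appears among the top-k retrieved chunks. A minimal brute-force sketch, assuming a hypothetical `hit_rate_at_k` helper and that `embed_model` exposes LlamaIndex's standard `get_query_embedding` / `get_text_embedding` methods:

```python
import math
from typing import Dict, List


def cosine(a: List[float], b: List[float]) -> float:
    """Cosine similarity between two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def hit_rate_at_k(
    query_embs: Dict[str, List[float]],
    doc_embs: Dict[str, List[float]],
    relevant_docs: Dict[str, List[str]],
    k: int = 5,
) -> float:
    """Fraction of queries whose top-k retrieved docs include a relevant doc."""
    hits = 0
    for query_id, q_emb in query_embs.items():
        # Rank every corpus chunk by cosine similarity to the query (brute force).
        ranked = sorted(doc_embs, key=lambda doc_id: cosine(q_emb, doc_embs[doc_id]), reverse=True)
        if any(doc_id in relevant_docs[query_id] for doc_id in ranked[:k]):
            hits += 1
    return hits / len(query_embs)


# Usage in this notebook (assumes `embed_model` and `val_dataset` from the cells above):
# query_embs = {qid: embed_model.get_query_embedding(q) for qid, q in val_dataset.queries.items()}
# doc_embs = {did: embed_model.get_text_embedding(t) for did, t in val_dataset.corpus.items()}
# print(hit_rate_at_k(query_embs, doc_embs, val_dataset.relevant_docs, k=5))
```

Running the same computation with the base `sentence-transformers/all-MiniLM-L6-v2` model gives a baseline to compare against.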