# Generate data for training custom embedding and reranking models

Bootstrapping and maintaining production-ready RAG pipelines requires optimizing various components, such as the LLM, the vector database, the embedding models and the rerankers.

In this tutorial, we will show how you can optimize and maintain your embedding models and rerankers using synthetic data and human feedback. To do so, we will use the `GenerateSentencePair` task from distilabel and the `sentence-transformers` library.

We will follow these steps:

- The dataset
- Synthetic data generation
  - retrieval
  - reranking
  - combined pipeline
- Data quality evaluation 
  - feature engineering
  - (optional) Argilla 
- Fine-tuning
  - retrieval
  - reranking

## Getting started

### Install the dependencies

To complete this tutorial, you need to install the distilabel SDK and a few third-party libraries via pip. You can install them by running the following command:

```python
!pip install "distilabel[openai]"
!pip install "sentence-transformers"
```

Let's make the needed imports:



```python
import os

from distilabel.llms import OpenAILLM
from distilabel.pipeline import Pipeline
from distilabel.steps.tasks import GenerateSentencePair
from distilabel.steps import LoadDataFromHub

from sentence_transformers import SentenceTransformer, CrossEncoder
import torch
```




### (Optional) deploy Argilla

You can skip this step or replace it with any other data evaluation tool. However, your model's quality will suffer if data quality is poor, so we do recommend looking at your data. If you have already deployed Argilla, you can skip this step. Otherwise, you can quickly deploy Argilla by following [this guide](https://docs.argilla.io/latest/getting_started/quickstart/).

You will also need to install Argilla as a distilabel extra:

```python
!pip install "distilabel[argilla, openai]"
```

Let's add the extra import:

```python
import argilla as rg
```

## The dataset

Before starting any project, it is always important to look at your data. Our data, [plaguss/argilla_sdk_docs_raw_unstructured](https://huggingface.co/datasets/plaguss/argilla_sdk_docs_raw_unstructured?row=0), is publicly available on the Hugging Face Hub, so we can take a quick look through [its dataset viewer within an embedded iFrame](https://huggingface.co/docs/hub/datasets-viewer-embed).

<iframe src="https://huggingface.co/datasets/plaguss/argilla_sdk_docs_raw_unstructured/embed/viewer" frameborder="0" width="100%" height="560px"></iframe>

As we can see, our dataset contains a column called `chunks`, which was obtained from the Argilla docs. Normally, you would need to download and chunk the data, but we will not cover that in this tutorial. For a full explanation of how this dataset was generated, please refer to [How we leveraged distilabel to create an Argilla 2.0 Chatbot](https://huggingface.co/blog/argilla-chatbot#downloading-and-chunking-data).

Alternatively, we can load the entire dataset to disk with `datasets.load_dataset`.

## Synthetic data generation

The [`GenerateSentencePair`](https://distilabel.argilla.io/latest/components-gallery/tasks/generatesentencepair/) component from `distilabel` can be used to generate training datasets for embedding models.

It is a predefined `Task` that, given an `anchor` sentence, generates data for a specific `action`. The supported actions are `"paraphrase"`, `"semantically-similar"`, `"query"` and `"answer"`. In our case, the `chunks` column corresponds to the `anchor`. We will use `query` to generate potential queries for fine-tuning a retrieval model, and `semantically-similar` to generate texts that are similar to the initial anchor for fine-tuning a reranking model.

We will set `triplet=True` to generate both positive and negative examples, which should help the model generalize better during fine-tuning. We will also set `hard_negative=True` to generate more challenging negatives that stay closer to the anchor and the topics it discusses.
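Negatives, and hard negatives in particular, are useful because of how triplets are typically used in training. A common option is a margin-based triplet objective, such as the one behind `TripletLoss` in `sentence-transformers`. This is an illustration only, and the loss used for fine-tuning may differ. Take an anchor $a$, a positive $p$, a negative $n$, a distance function $d$ and a margin $\varepsilon > 0$. The objective is:

$$
\mathcal{L}(a, p, n) = \max\bigl(0,\; d(a, p) - d(a, n) + \varepsilon\bigr)
$$

The loss only reaches zero once the negative is at least $\varepsilon$ farther from the anchor than the positive. Hard negatives start close to the anchor, so $d(a, n)$ stays small. As a result, they produce a non-zero training signal more often than easy, off-topic negatives.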

Lastly, we can seed the LLM with `context` to generate more relevant examples.

```python
context = (
    """
The text is a chunk from technical Python SDK documentation of Argilla.
Argilla is a collaboration tool for AI engineers and domain experts to build high-quality datasets.
Along with prose explanations, the text chunk may include code snippets and Python references.
"""
)
```

### Retrieval

For retrieval, we will use the `query` action to generate queries that the texts in the `chunks` column can answer. These queries serve as training data for fine-tuning a retrieval model.

```python
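generate_sentence_pair = GenerateSentencePair(
    triplet=True,  # set to `False` to generate only positive examples
    hard_negative=True,
    action="query",
    llm=llm,  # the OpenAILLM instance created in the combined pipeline below
    input_batch_size=10,
    context=context,
)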
```

### Reranking

For reranking, we will use the `semantically-similar` action to generate texts that are similar to the initial anchor. We will then use these texts to fine-tune a reranking model.

```python
generate_sentence_pair = GenerateSentencePair(
    triplet=True,
    hard_negative=True,
    action="semantically-similar",
    llm=llm,
    input_batch_size=10,
    context=context,
)
```

### Combined pipeline

We will now use the `GenerateSentencePair` task to generate synthetic data for both the retrieval and the reranking models in a single pipeline. Note that we map the `chunks` column to the `anchor` argument.

```python
llm = OpenAILLM(model="gpt-4o", api_key=os.getenv("OPENAI_API_KEY"))

with Pipeline(name="generate") as pipeline:
    load_dataset = LoadDataFromHub(
        num_examples=15,
        output_mappings={"chunks": "anchor"},
    )
    generate_retrieval_pairs = GenerateSentencePair(
        name="generate_retrieval_pairs",
        triplet=True,
        hard_negative=True,
        action="query",
        llm=llm,
        input_batch_size=10,
        context=context,
    )
    generate_reranking_pairs = GenerateSentencePair(
        name="generate_reranking_pairs",
        triplet=True,
        hard_negative=True,
        action="semantically-similar",
        llm=llm,
        input_batch_size=10,
        context=context,
    )

    load_dataset >> [generate_retrieval_pairs, generate_reranking_pairs]
```

Next, we can execute the pipeline with `pipeline.run`. We will pass runtime `parameters` to specific components: the dataset to load and the generation settings for each LLM.

```python
generation_kwargs = {
    "llm": {
        "generation_kwargs": {
            "temperature": 0.7,
            "max_new_tokens": 512,
        }
    }
}

distiset = pipeline.run(
    parameters={
        load_dataset.name: {
            "repo_id": "plaguss/argilla_sdk_docs_raw_unstructured",
            "split": "train",
        },
        generate_retrieval_pairs.name: generation_kwargs,
        generate_reranking_pairs.name: generation_kwargs,
    },
    use_cache=False,  # regenerate from scratch; set to True to reuse cached results
)
```

```python
distiset
```

```
Distiset({
    generate_reranking_pairs: DatasetDict({
        train: Dataset({
            features: ['filename', 'anchor', 'repo_name', 'positive', 'negative', 'distilabel_metadata', 'model_name'],
            num_rows: 15
        })
    })
    generate_retrieval_pairs: DatasetDict({
        train: Dataset({
            features: ['filename', 'anchor', 'repo_name', 'positive', 'negative', 'distilabel_metadata', 'model_name'],
            num_rows: 15
        })
    })
})
```

Our pipeline has two leaf nodes, so the resulting `Distiset` contains two configurations: one for the retrieval data and one for the reranking data.

```python
distiset["generate_reranking_pairs"]["train"][0]
```

```
{'filename': 'argilla-python/docs/index.md',
 'anchor': 'description: Argilla is a collaboration platform for AI engineers and domain experts that require high-quality outputs, full data ownership, and overall efficiency.\nhide: navigation\n\nWelcome to Argilla\n\nArgilla is a collaboration platform for AI engineers and domain experts that require high-quality outputs, full data ownership, and overall efficiency.',
 'repo_name': 'argilla-io/argilla-python',
 'positive': 'description: Argilla is a collaboration tool designed for AI engineers and domain experts who need high-quality outputs, full data control, and maximum efficiency.\nhide: navigation\n\nWelcome to Argilla\n\nArgilla is a collaboration tool designed for AI engineers and domain experts who need high-quality outputs, full data control, and maximum efficiency.',
 'negative': 'description: Argilla is a platform for marketing professionals and sales teams that prioritizes customer engagement, brand visibility, and revenue gro
```

```python
distiset["generate_retrieval_pairs"]["train"][0]
```

```
{'filename': 'argilla-python/docs/index.md',
 'anchor': 'description: Argilla is a collaboration platform for AI engineers and domain experts that require high-quality outputs, full data ownership, and overall efficiency.\nhide: navigation\n\nWelcome to Argilla\n\nArgilla is a collaboration platform for AI engineers and domain experts that require high-quality outputs, full data ownership, and overall efficiency.',
 'repo_name': 'argilla-io/argilla-python',
 'positive': 'What is Argilla and how does it benefit AI engineers and domain experts?',
 'negative': "How does Argilla's interface compare with other project management tools?",
 'distilabel_metadata': {'raw_output_generate_retrieval_pairs': "## Positive\n\nWhat is Argilla and how does it benefit AI engineers and domain experts?\n\n## Negative\n\nHow does Argilla's interface compare with other project management tools?"},
 'model_name': 'gpt-4o'}
```
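The raw LLM response kept in `distilabel_metadata` uses `## Positive` and `## Negative` headings, and distilabel turns it into the `positive` and `negative` columns. The helper below is a hypothetical sketch of that split, written only to make the format explicit. `parse_sentence_pair_output` is not a distilabel function, and distilabel's own parsing may handle more edge cases.

```python
import re


def parse_sentence_pair_output(raw: str) -> dict:
    """Split a '## Positive ... ## Negative ...' completion into its two parts.

    Hypothetical helper for illustration; distilabel already does this parsing.
    Missing sections are returned as None.
    """
    # Capture the text under each heading, up to the next '## ' heading or the end.
    pattern = r"##\s*(Positive|Negative)\s*\n(.*?)(?=\n##\s|\Z)"
    parsed = {"positive": None, "negative": None}
    for label, text in re.findall(pattern, raw, flags=re.DOTALL):
        parsed[label.lower()] = text.strip()
    return parsed


raw = (
    "## Positive\n\nWhat is Argilla and how does it benefit AI engineers and domain experts?"
    "\n\n## Negative\n\nHow does Argilla's interface compare with other project management tools?"
)
print(parse_sentence_pair_output(raw))
```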

Looking at these initial examples, we can see they nicely capture the essence of the `chunks` column but we will need to evaluate the quality of the data a bit more before we can use it for fine-tuning.
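A cheap first check, before any model-based scoring, is to drop rows with missing text. We assume that a generation whose output could not be parsed leaves `positive` or `negative` empty or `None`; verify this on your own run. The predicate below is a minimal sketch, and `is_complete` and `REQUIRED_COLUMNS` are hypothetical names.

```python
# Columns that every training triplet needs (names taken from the Distiset above).
REQUIRED_COLUMNS = ("anchor", "positive", "negative")


def is_complete(row: dict) -> bool:
    # A row is usable only if every required column holds a non-blank string.
    return all(
        isinstance(row.get(col), str) and row[col].strip() != ""
        for col in REQUIRED_COLUMNS
    )


rows = [
    {"anchor": "Argilla docs chunk", "positive": "What is Argilla?", "negative": "How do I bake bread?"},
    {"anchor": "Another chunk", "positive": None, "negative": None},  # e.g. a failed generation
]
print([is_complete(row) for row in rows])
```

Because the predicate takes a single row as a dict, it can be passed directly to `Dataset.filter` from the `datasets` library, e.g. `distiset["generate_retrieval_pairs"]["train"].filter(is_complete)`.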

## Data quality evaluation 

Data is never as clean as it could be, and this holds for synthetically generated data too. It is therefore always worth spending some time looking at your data.

### Feature engineering

To evaluate the quality of our data, we will use features of the models that we intend to fine-tune as a proxy for data quality. We can then use these features to filter out the best examples.

In order to choose a good default model, we will use the [Massive Text Embedding Benchmark (MTEB) Leaderboard](https://huggingface.co/spaces/mteb/leaderboard). We want to optimize for size and speed, so we will set model size `<100M` and then filter for `Retrieval` and `Reranking` based on the highest average score, resulting in [Snowflake/snowflake-arctic-embed-s](https://huggingface.co/Snowflake/snowflake-arctic-embed-s) and [sentence-transformers/all-MiniLM-L12-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2) respectively.

<iframe
	src="https://mteb-leaderboard.hf.space"
	frameborder="0"
	width="100%"
	height="600"
></iframe>

#### Retrieval

For retrieval, we will use the current embedding model to compute similarities for the `anchor-positive`, `positive-negative` and `anchor-negative` pairs. We assume that overlapping similarity scores will make it harder for the model to generalize. For example, a negative may be as close to the anchor as the positive is. We can therefore use these features to evaluate the quality of our data.
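"Overlap" can be made concrete in several ways. The sketch below uses one possible reading, an AUC-style statistic that we chose for illustration. It reports the fraction of (positive, negative) pairs, taken across the whole dataset, in which a negative is at least as similar to its anchor as a positive is to its own. It expects plain lists of floats, such as the `similarity-anchor-positive` and `similarity-anchor-negative` columns computed in the next step. `overlap_ratio` is a hypothetical name, and the toy values are made up.

```python
def overlap_ratio(sims_anchor_positive, sims_anchor_negative):
    """Fraction of (positive, negative) pairs where the negative's anchor
    similarity is greater than or equal to the positive's.

    0.0 means the two similarity distributions are fully separated; values
    around 0.5 or higher mean negatives are about as close to anchors as
    positives, which is the situation we want to avoid.
    """
    total = len(sims_anchor_positive) * len(sims_anchor_negative)
    if total == 0:
        return 0.0
    # Compare every positive score with every negative score (O(n * m), fine for small datasets).
    overlapping = sum(
        1 for p in sims_anchor_positive for n in sims_anchor_negative if n >= p
    )
    return overlapping / total


# Toy similarity values for three triplets.
print(overlap_ratio([0.82, 0.77, 0.64], [0.41, 0.70, 0.30]))
```

The statistic is computed across rows rather than within each row, so it summarizes how separable the two score distributions are. It does not say whether any single triplet is well formed.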

```python
model_id = "Snowflake/snowflake-arctic-embed-s"  # Hugging Face model ID

model_retrieval = SentenceTransformer(
    model_id, device="cuda" if torch.cuda.is_available() else "cpu"
)
```

Next, we will encode the generated text pairs and compute the similarities. 

```python
from sklearn.metrics.pairwise import cosine_similarity

def get_embeddings(texts):
    # Encode a batch of texts with the retrieval model and return plain Python lists.
    vectors = model_retrieval.encode(texts)
    return [vector.tolist() for vector in vectors]


def get_similarities(vector_batch_a, vector_batch_b):
    # Cosine similarity between the i-th vector of batch A and the i-th vector of batch B.
    similarities = []
    for vector_a, vector_b in zip(vector_batch_a, vector_batch_b):
        similarity = cosine_similarity([vector_a], [vector_b])[0][0]
        similarities.append(float(similarity))
    return similarities

def format_data_retriever(batch):
    # Embed each text column, then compute the three pairwise similarity features.
    batch["anchor-vector"] = get_embeddings(batch["anchor"])
    batch["positive-vector"] = get_embeddings(batch["positive"])
    batch["negative-vector"] = get_embeddings(batch["negative"])
    batch["similarity-positive-negative"] = get_similarities(batch["positive-vector"], batch["negative-vector"])
    batch["similarity-anchor-positive"] = get_similarities(batch["anchor-vector"], batch["positive-vector"])
    batch["similarity-anchor-negative"] = get_similarities(batch["anchor-vector"], batch["negative-vector"])
    return batch

dataset_generate_retrieval_pairs = distiset["generate_retrieval_pairs"]["train"].map(format_data_retriever, batched=True, batch_size=250)
```


#### Reranking

For reranking, we will use an existing reranker model to compute relevance scores for the `anchor-positive`, `positive-negative` and `anchor-negative` pairs. We make the same assumption as for the retrieval model.

```python
model_id = "sentence-transformers/all-MiniLM-L12-v2"

model = CrossEncoder(model_id)
```

```
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at sentence-transformers/all-MiniLM-L12-v2 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
```


Next, we will score the generated text pairs with the reranker. We will also compute an `anchor-vector` with the retrieval model to enable semantic search.

```python
def format_data_reranker(batch):
    # Anchor vectors come from the retrieval embedding model and enable semantic search.
    batch["anchor-vector"] = get_embeddings(batch["anchor"])
    # The cross-encoder scores raw text pairs, so pass a list of (text_a, text_b) tuples.
    batch["similarity-positive-negative"] = model.predict(list(zip(batch["positive"], batch["negative"]))).tolist()
    batch["similarity-anchor-positive"] = model.predict(list(zip(batch["anchor"], batch["positive"]))).tolist()
    batch["similarity-anchor-negative"] = model.predict(list(zip(batch["anchor"], batch["negative"]))).tolist()
    return batch

dataset_generate_reranking_pairs = distiset["generate_reranking_pairs"]["train"].map(format_data_reranker, batched=True, batch_size=250)
```

And voila, we have our proxies for quality evaluation which we can use to filter out the best and worst examples.
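These proxies can be turned into a simple filter. The sketch below keeps a triplet when the anchor is clearly closer to the positive than to the negative, and when the positive and negative are not near-duplicates. `keep_example` and both thresholds are hypothetical and untuned. The reranker scores live on a different scale than cosine similarities, so pick thresholds separately for each dataset.

```python
def keep_example(
    row: dict, min_margin: float = 0.05, max_positive_negative: float = 0.95
) -> bool:
    """Hypothetical quality filter over the similarity columns computed above."""
    # How much closer the anchor is to the positive than to the negative.
    margin = row["similarity-anchor-positive"] - row["similarity-anchor-negative"]
    # Reject triplets whose positive and negative are almost the same text.
    not_duplicate = row["similarity-positive-negative"] <= max_positive_negative
    return margin >= min_margin and not_duplicate


rows = [
    {"similarity-anchor-positive": 0.82, "similarity-anchor-negative": 0.41, "similarity-positive-negative": 0.55},
    {"similarity-anchor-positive": 0.60, "similarity-anchor-negative": 0.58, "similarity-positive-negative": 0.55},
    {"similarity-anchor-positive": 0.90, "similarity-anchor-negative": 0.30, "similarity-positive-negative": 0.97},
]
print([keep_example(row) for row in rows])
```

With the `datasets` library, the same function can be applied with `dataset_generate_retrieval_pairs.filter(keep_example)`. Keep `min_margin` small: a large margin discards exactly the hard negatives requested with `hard_negative=True`.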

### (Optional) Argilla

To get the most out of our data and actually look at it, we will use Argilla. If you are not familiar with Argilla, we recommend taking a look at the [Argilla quickstart docs](https://docs.argilla.io/latest/getting_started/quickstart/). Alternatively, you can use your Hugging Face account to log in to the [Argilla demo Space](https://argilla-argilla-template-space.hf.space).

To start exploring data, we first need to define an `argilla.Dataset`. We will create a basic dataset with an input `TextField` for the `anchor` and output `TextQuestion`s for the `positive` and `negative` pairs. Additionally, we will use the `filename` as a `TermsMetadataProperty`. Lastly, we will reuse the vectors obtained in the previous step to allow for semantic search, and we will add the similarity scores as metadata for basic filtering and sorting.

First, we need to define the settings for our Argilla datasets. We will create two different datasets, one for the retrieval data and one for the reranking data, so that our annotators can focus on the task at hand.

```python
import argilla as rg
from argilla._exceptions import ConflictError

api_key = "ohh so secret"
api_url = "https://davidberenstein1957-my-argilla.hf.space"

client = rg.Argilla(api_url=api_url, api_key=api_key)

settings = rg.Settings(
    fields=[
        rg.TextField("anchor")
    ],
    questions=[
        rg.TextQuestion("positive"),
        rg.TextQuestion("negative")
    ],
    metadata=[
        rg.TermsMetadataProperty("filename"),
        rg.FloatMetadataProperty("similarity-positive-negative"),
        rg.FloatMetadataProperty("similarity-anchor-positive"),
        rg.FloatMetadataProperty("similarity-anchor-negative"),
    ],
    vectors=[
        # The anchor vectors were computed with the retrieval embedding model.
        rg.VectorField("anchor-vector", dimensions=model_retrieval.get_sentence_embedding_dimension())
    ]
)
rg_datasets = []
for dataset_name in ["generate_retrieval_pairs", "generate_reranking_pairs"]:
    ds = rg.Dataset(
        name=dataset_name,
        settings=settings,
        client=client,
    )
    try:
        ds.create()
    except ConflictError:
        # The dataset already exists on the server, so retrieve it instead.
        ds = client.datasets(dataset_name)
    rg_datasets.append(ds)
```

Now that our datasets are defined in Argilla, we can upload our data.

```python
ds_datasets = [dataset_generate_retrieval_pairs, dataset_generate_reranking_pairs]

for rg_dataset, ds_dataset in zip(rg_datasets, ds_datasets):
    records = []  # reset per dataset so each dataset only receives its own records
    for idx, entry in enumerate(ds_dataset):
        records.append(
            rg.Record(
                id=str(idx),
                fields={"anchor": entry["anchor"]},
                suggestions=[
                    rg.Suggestion("positive", value=entry["positive"], agent="gpt-4o", type="model"),
                    rg.Suggestion("negative", value=entry["negative"], agent="gpt-4o", type="model"),
                ],
                metadata={
                    "filename": entry["filename"],
                    "similarity-positive-negative": entry["similarity-positive-negative"],
                    "similarity-anchor-positive": entry["similarity-anchor-positive"],
                    "similarity-anchor-negative": entry["similarity-anchor-negative"]
                },
                vectors={"anchor-vector": entry["anchor-vector"]}
            )
        )
    rg_dataset.records.log(records)
```

Now we can explore the UI and add a final human touch to get the most out of our dataset.