# Introduction

In this notebook, we'll look at benchmarking the retrieval performance of embedding search given a user query. We'll be using the synthetic questions that we generated in the previous notebook to do so.

We'll do this in two steps:

1. First, we'll take the dataset we used to generate the synthetic questions and ingest it into a local LanceDB instance.
2. Then we'll show how to measure recall and MRR at different values of k for our synthetic questions.

## Data Ingestion

We'll start by using `lancedb` to ingest the dataset into a local database. LanceDB integrates nicely with Pydantic and also handles embedding and batching for all of our data. Let's start by creating a new LanceDB database and defining a Pydantic model for our chunks.

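Note that we're using the `LanceModel` class from `lancedb.pydantic` to define our model. It is a subclass of Pydantic's `BaseModel` that adds functionality for working with LanceDB. Here, `func.SourceField()` marks `text` as the field to embed, and `func.VectorField()` marks `vector` as the field where the embedding is stored.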

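```python
import lancedb
from lancedb.pydantic import LanceModel, Vector
from lancedb.embeddings import get_registry

# Create (or open) a local LanceDB database in the ./lancedb directory
db = lancedb.connect("./lancedb")

# OpenAI embedding function from LanceDB's registry (reads OPENAI_API_KEY)
func = get_registry().get("openai").create(name="text-embedding-3-small")


class Chunk(LanceModel):
    id: str
    # Source field: this text is embedded automatically when rows are added
    text: str = func.SourceField()
    # Vector field sized to the embedding model's output dimensionality
    vector: Vector(func.ndims()) = func.VectorField()
```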
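Conceptually, `SourceField` and `VectorField` let us insert rows that contain only `id` and `text`: the embedding function is applied to the source field and the result is stored in the vector field. The plain-Python sketch below mimics that flow with a hypothetical `toy_embed` function (a hash, not a semantic embedding) and a dataclass standing in for `Chunk`. It illustrates the idea only and is not how LanceDB is implemented internally.

```python
import hashlib
import math
from dataclasses import dataclass, field

DIMS = 8  # toy size; text-embedding-3-small produces 1536-dimensional vectors


def toy_embed(text: str, dims: int = DIMS) -> list[float]:
    """Hypothetical stand-in for an embedding model: deterministic and unit-length, NOT semantic."""
    digest = hashlib.sha256(text.encode()).digest()
    raw = [digest[i] - 128 for i in range(dims)]
    norm = math.sqrt(sum(x * x for x in raw)) or 1.0
    return [x / norm for x in raw]


@dataclass
class ToyChunk:
    id: str
    text: str  # plays the role of func.SourceField()
    vector: list[float] = field(default_factory=list)  # plays the role of func.VectorField()


def ingest(rows: list[dict]) -> list[ToyChunk]:
    # Rows carry only id and text; the vector is derived from the source field
    return [ToyChunk(id=r["id"], text=r["text"], vector=toy_embed(r["text"])) for r in rows]


chunks = ingest([{"id": "1", "text": "SELECT 1"}, {"id": "2", "text": "SELECT 2"}])
print(len(chunks), len(chunks[0].vector))  # 2 8
```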

Now that we've defined our model and created our database, let's ingest our dataset. We'll define a function to batch our data, then create a new table from the Pydantic model above and add all of our chunks to it. Each chunk is a SQL query from the dataset, and its `id` is the SHA-256 hash of that query.

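```python
import datasets
import hashlib


def batch_items(items: list[dict], batch_size: int = 100):
    """Yield consecutive lists of at most `batch_size` items."""
    cur = []
    for item in items:
        cur.append(item)
        if len(cur) == batch_size:
            yield cur
            cur = []
    # Flush the final, possibly smaller, batch
    if cur:
        yield cur


def hash_query(query: str) -> str:
    """Deterministic id for a query: the hex SHA-256 digest of its text."""
    return hashlib.sha256(query.encode()).hexdigest()


dataset = datasets.load_dataset("567-labs/bird-dev-snippets")

# Each row becomes a chunk: the SQL query is the text we embed, its hash is the id
formatted_dataset = [
    {
        "text": item["query"],
        "id": hash_query(item["query"]),
    }
    for item in dataset["original"]
]
```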
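`batch_items` is a small generator that yields lists of up to `batch_size` records, with a shorter final batch. A sketch of the same behavior (for positive batch sizes) using `itertools.islice` is below; the name `batch_items_islice` is ours. On Python 3.12+, `itertools.batched` does a similar job but yields tuples rather than lists. With a batch size of 300, the six iterations `tqdm` reports in the ingestion cell correspond to between 1,501 and 1,800 rows.

```python
from itertools import islice
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")


def batch_items_islice(items: Iterable[T], batch_size: int = 100) -> Iterator[list[T]]:
    """Yield consecutive lists of at most `batch_size` items; the last may be shorter."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    it = iter(items)
    # islice takes up to batch_size items; an empty list means the input is exhausted
    while batch := list(islice(it, batch_size)):
        yield batch


print([len(batch) for batch in batch_items_islice(range(650), 300)])  # [300, 300, 50]
```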


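Now that our data is formatted, we can ingest it into the database. We'll create a new table from the `Chunk` schema (overwriting any existing `chunks` table), split the records into batches of 300, and add each batch to the table. LanceDB computes the embeddings for the `text` field as each batch is added.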

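```python
from tqdm import tqdm

# Create the table from the Chunk schema, replacing any existing "chunks" table
table = db.create_table("chunks", schema=Chunk, mode="overwrite")

batches = batch_items(formatted_dataset, 300)

# Each add call embeds the batch's text and writes the rows
for batch in tqdm(batches):
    table.add(batch)
```

```
6it [00:13,  2.29s/it]
```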
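Because each chunk's `id` is the SHA-256 hash of its SQL text, ids are stable across runs, and identical snippets end up with the same id. We don't rely on `table.add` to enforce unique ids here, so a quick duplicate check before ingestion can be useful. A stdlib sketch, with hypothetical snippets and a helper named `find_duplicate_ids`:

```python
import hashlib
from collections import Counter


def hash_query(query: str) -> str:
    # Same scheme as above: hex SHA-256 digest of the query text
    return hashlib.sha256(query.encode()).hexdigest()


def find_duplicate_ids(texts: list[str]) -> dict[str, int]:
    """Map each id that occurs more than once to its number of occurrences."""
    counts = Counter(hash_query(text) for text in texts)
    return {doc_id: n for doc_id, n in counts.items() if n > 1}


# Hypothetical snippets; the first and last are identical
snippets = [
    "SELECT COUNT(*) FROM races",
    "SELECT name FROM circuits",
    "SELECT COUNT(*) FROM races",
]
dupes = find_duplicate_ids(snippets)
print(len(dupes), list(dupes.values()))  # 1 [2]
```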


## Evaluation

Now let's evaluate how well embedding search retrieves the right chunk for each synthetic question. We'll measure recall and MRR (mean reciprocal rank) at several values of k, where k is the number of top results we keep.
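For reference, here are the metrics for a single question $q$ with ranked results $p_1, \dots, p_n$ and a set $G_q$ of expected snippets (notation ours). The two functions below compute the per-question values after the results are truncated to the top $k$; averaging over the question set $Q$ gives the reported scores.

$$
\mathrm{Recall@}k(q) = \frac{\lvert G_q \cap \{p_1, \dots, p_k\} \rvert}{\lvert G_q \rvert}
$$

$$
\mathrm{RR@}k(q) =
\begin{cases}
\dfrac{1}{\min\{\, i \le k : p_i \in G_q \,\}} & \text{if some } p_i \in G_q \text{ with } i \le k,\\[4pt]
0 & \text{otherwise,}
\end{cases}
\qquad
\mathrm{MRR@}k = \frac{1}{\lvert Q \rvert} \sum_{q \in Q} \mathrm{RR@}k(q)
$$

For example, if the only expected snippet appears at rank 3, then $\mathrm{RR@}k(q) = 1/3$ for every $k \ge 3$ and $0$ for $k < 3$.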

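```python
def calculate_mrr(predictions: list[str], gt: list[str]) -> float:
    """Reciprocal rank of the highest-ranked ground-truth item (0 if none were retrieved)."""
    mrr = 0.0
    for label in gt:
        if label in predictions:
            # list.index is 0-based, so add 1 to get the rank
            mrr = max(mrr, 1 / (predictions.index(label) + 1))
    return mrr


def calculate_recall(predictions: list[str], gt: list[str]) -> float:
    """Fraction of ground-truth items that appear in the predictions."""
    return len([label for label in gt if label in predictions]) / len(gt)
```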
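A quick worked example with hypothetical ids: the only relevant item sits at rank 3, so it is missed at k=2 and found at k=3 with a reciprocal rank of 1/3. The cell repeats the two functions so it runs on its own.

```python
def calculate_mrr(predictions: list[str], gt: list[str]) -> float:
    # Reciprocal rank of the highest-ranked relevant item (0 if none are retrieved)
    mrr = 0.0
    for label in gt:
        if label in predictions:
            mrr = max(mrr, 1 / (predictions.index(label) + 1))
    return mrr


def calculate_recall(predictions: list[str], gt: list[str]) -> float:
    # Fraction of relevant items that appear anywhere in predictions
    return len([label for label in gt if label in predictions]) / len(gt)


ranked = ["q7", "q2", "q9", "q4"]  # hypothetical retrieved ids, best first
expected = ["q9"]                  # hypothetical single relevant id

print(calculate_mrr(ranked[:2], expected), calculate_recall(ranked[:2], expected))  # 0.0 0.0
print(calculate_mrr(ranked[:3], expected), calculate_recall(ranked[:3], expected))  # 0.3333333333333333 1.0
# With two relevant ids, MRR uses the best-ranked one (q2 at rank 2)
print(calculate_mrr(ranked, ["q2", "q9"]), calculate_recall(ranked[:2], ["q2", "q9"]))  # 0.5 0.5
```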
    

### Braintrust

Now let's see how to compute these metrics with Braintrust and LanceDB. We'll start with a function that takes a query and returns the text of the top k results from our table.

```python
def retrieve(query: str, k: int = 10) -> list[str]:
    # The query string is embedded with the table's embedding function before searching
    results = table.search(query).limit(k).to_list()
    return [result["text"] for result in results]


retrieve("What is the capital of France?")[:2]
```

```
["SELECT T1.driverId FROM lapTimes AS T1 INNER JOIN races AS T2 on T1.raceId = T2.raceId WHERE T2.name = 'French Grand Prix' AND T1.lap = 3 ORDER BY T1.time DESC LIMIT 1",
 'SELECT T2.City FROM frpm AS T1 INNER JOIN schools AS T2 ON T1.CDSCode = T2.CDSCode GROUP BY T2.City ORDER BY SUM(T1.`Enrollment (K-12)`) ASC LIMIT 5']
```
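Note that the search returns k results even for a question unrelated to any snippet: it ranks stored vectors by distance to the query embedding and keeps the closest k. The toy sketch below (hypothetical 2-D vectors and a helper named `brute_force_search`, not LanceDB's implementation) shows that exhaustive nearest-neighbor ranking. It uses L2 distance, which we believe is LanceDB's default metric; because OpenAI embeddings are normalized to unit length, L2 distance and cosine similarity rank results identically.

```python
import math


def l2_distance(a: list[float], b: list[float]) -> float:
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def brute_force_search(query_vec: list[float], rows: list[dict], k: int = 10) -> list[str]:
    """Return the text of the k rows whose vectors are closest to query_vec."""
    ranked = sorted(rows, key=lambda row: l2_distance(query_vec, row["vector"]))
    return [row["text"] for row in ranked[:k]]


# Hypothetical unit-length 2-D "embeddings"
rows = [
    {"text": "SELECT name FROM races", "vector": [1.0, 0.0]},
    {"text": "SELECT City FROM schools", "vector": [0.0, 1.0]},
    {"text": "SELECT driverId FROM lapTimes", "vector": [0.6, 0.8]},
]

# Some rows are always returned, however unrelated the query is
print(brute_force_search([0.8, 0.6], rows, k=2))
# ['SELECT driverId FROM lapTimes', 'SELECT name FROM races']
```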

Next, we'll build a dictionary of metric functions, one for each combination of metric (MRR or recall) and cutoff k. Each function truncates the ranked results to the top k before scoring them.

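```python
import itertools

eval_metrics = [["mrr", calculate_mrr], ["recall", calculate_recall]]
sizes = [3, 5, 10, 15, 25]

# One scorer per (metric, k) pair, e.g. "mrr@3" or "recall@25".
# The default arguments bind metric_fn and size when each lambda is created;
# without them, every lambda would use the loop's final values (late binding).
metrics = {
    f"{metric_name}@{size}": lambda predictions, gt, m=metric_fn, s=size: m(predictions[:s], gt)
    for (metric_name, metric_fn), size in itertools.product(eval_metrics, sizes)
}
```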
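The `m=metric_fn, s=size` default arguments in the metrics cell matter. Python closures look variables up when the function is called, not when it is defined, so lambdas created in a loop or comprehension all see the loop variable's final value. This small demo (with a hypothetical `len_stub` metric and our own variable names) shows the difference.

```python
def len_stub(predictions, gt):
    # Hypothetical "metric" that just reports how many predictions it received
    return len(predictions)


demo_sizes = [3, 5, 10]
preds = list(range(25))

# Late binding: each lambda reads `size` when called, after the comprehension has finished
late = {f"m@{size}": lambda p, g: len_stub(p[:size], g) for size in demo_sizes}

# Default argument: `s` is bound to the current size when each lambda is created
bound = {f"m@{size}": lambda p, g, s=size: len_stub(p[:s], g) for size in demo_sizes}

print({name: fn(preds, []) for name, fn in late.items()})   # {'m@3': 10, 'm@5': 10, 'm@10': 10}
print({name: fn(preds, []) for name, fn in bound.items()})  # {'m@3': 3, 'm@5': 5, 'm@10': 10}
```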

```python
from braintrust import Score, init_dataset, Eval


def evaluate_braintrust(input, output, **kwargs):
    # Hash the retrieved and expected queries so they are compared by id
    hashed_queries = [hash_query(query) for query in output]
    hashed_expected = [hash_query(query) for query in kwargs["expected"]]
    # Return one Score per metric@k for this question
    return [
        Score(
            name=metric,
            score=score_fn(hashed_queries, hashed_expected),
            metadata={"query": input, "result": output, **kwargs["metadata"]},
        )
        for metric, score_fn in metrics.items()
    ]


# Top-level await works in a notebook; retrieve 25 results, the largest k we score
await Eval(
    "Retrieval",
    data=init_dataset(project="Retrieval", name="Synthetic Questions"),
    task=lambda input: retrieve(input, 25),
    scores=[evaluate_braintrust],
)
```

```
Experiment add-braintrust-support-1729509406 is running at https://www.braintrust.dev/app/567/p/Retrieval/experiments/add-braintrust-support-1729509406
Retrieval (data): 145it [00:01, 116.71it/s]
Retrieval (tasks): 100%|██████████| 145/145 [00:09<00:00, 15.39it/s]

add-braintrust-support-1729509406 compared to add-braintrust-support-1729509191:
76.78% (-01.93%) 'mrr@3'     score	(0 improvements, 20 regressions)
77.57% (-01.13%) 'mrr@5'     score	(0 improvements, 15 regressions)
78.37% (-00.34%) 'mrr@10'    score	(0 improvements, 6 regressions)
78.66% (-00.04%) 'mrr@15'    score	(0 improvements, 1 regressions)
78.71% (-) 'mrr@25'    score	(0 improvements, 0 regressions)
86.21% (-13.79%) 'recall@3'  score	(0 improvements, 20 regressions)
89.66% (-10.34%) 'recall@5'  score	(0 improvements, 15 regressions)
95.86% (-04.14%) 'recall@10' score	(0 improvements, 6 regressions)
99.31% (-00.69%) 'recall@15' score	(0 improvements, 1 regressions)
100.00% (-) 'recall@25' score	(0 improvements, 0 regressions)

4.53s (-92.24%) 'duration'	(119 improvements, 26 regressions)

See results for add-braintrust-support-1729509406 at https://www.braintrust.dev/app/567/p/Retrieval/experiments/add-braintrust-support-1729509406
```
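The summary percentages appear to be averages over the 145 questions. If each synthetic question has exactly one expected snippet (an assumption the numbers are consistent with), recall@k is either 0 or 1 per question, so the percentages convert back into hit counts. The sketch below does that arithmetic; the implied miss counts match the regression counts in the summary. The baseline run's implied scores (100% recall and roughly 78.7% MRR at every k) look like what you would get if every cutoff had been evaluated at k=25, though this output alone can't confirm that.

```python
n_questions = 145  # rows in the Braintrust run above

# Hit counts implied by the reported percentages, assuming exactly one
# expected snippet per question (so each per-question recall@k is 0 or 1)
hits = {"recall@3": 125, "recall@5": 130, "recall@10": 139, "recall@15": 144, "recall@25": 145}

lines = []
for metric, n_hit in hits.items():
    pct = 100 * n_hit / n_questions
    lines.append(f"{metric}: {pct:.2f}% (misses: {n_questions - n_hit})")

print("\n".join(lines))
```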

