# OpenAI x Qdrant

## Binary Quantization with OpenAI Ada-003 Embeddings

This notebook demonstrates how to use Qdrant to index and search OpenAI Ada-003 embeddings (the `text-embedding-3-small` and `text-embedding-3-large` models). We will compare the results of approximate search over binary-quantized vectors with those of exact, brute-force search. We will use [Qdrant Cloud](https://qdrant.to/cloud?utm_source=qdrant&utm_medium=social&utm_campaign=binary-openai-v3&utm_content=article) to index and search the embeddings.
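Binary Quantization (BQ) keeps one bit per dimension instead of a 32-bit float. The NumPy sketch below illustrates the idea, assuming the commonly described scheme in which positive components map to 1 and all others to 0, with candidates compared by Hamming distance (the number of differing bits). It is an illustration only, not Qdrant's internal implementation, and `binarize` and `hamming_distances` are hypothetical helper names.

```python
import numpy as np


def binarize(vectors: np.ndarray) -> np.ndarray:
    """Map each float component to one bit (1 if > 0, else 0) and pack 8 bits per byte."""
    bits = (vectors > 0).astype(np.uint8)
    return np.packbits(bits, axis=-1)


def hamming_distances(query_code: np.ndarray, codes: np.ndarray) -> np.ndarray:
    """Count the differing bits between one packed query code and each packed code."""
    xor = np.bitwise_xor(codes, query_code)
    return np.unpackbits(xor, axis=-1).sum(axis=-1)


rng = np.random.default_rng(0)
vectors = rng.normal(size=(4, 512)).astype(np.float32)
codes = binarize(vectors)

# Bytes per vector: float32 original vs. packed binary code
print(vectors.nbytes // len(vectors), codes.nbytes // len(codes))  # 2048 64
# Distance of the first vector to all four; the first entry is 0 (query vs. itself)
print(hamming_distances(codes[0], codes))
```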

In [None]:
import json
import os
import random

import loguru
import numpy as np
from datasets import load_dataset
from datasets.exceptions import DatasetNotFoundError
from dotenv import load_dotenv
from qdrant_client import QdrantClient, models
from qdrant_client.models import PointStruct
from tqdm import tqdm

load_dotenv()  # take environment variables from .env
logger = loguru.logger
logger.add("logs.log", format="{time} {level} {message}", level="INFO")

## Set up a Qdrant client connection

We use the `qdrant-client` Python package to interact with Qdrant; you can install it with `pip install qdrant-client`. We manage our dependencies with Poetry, so you can install all of them with `poetry install`.

In [None]:
client = QdrantClient(
    url=os.getenv("QDRANT_URL"),
    api_key=os.getenv("QDRANT_API_KEY"),
    timeout=100,
)
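The client takes its cluster URL and API key from the `QDRANT_URL` and `QDRANT_API_KEY` environment variables, which `load_dotenv()` loads from a `.env` file. A small fail-fast check such as the hypothetical `missing_env_vars` helper below (a sketch, not part of the original notebook) reports unset variables before any request is made.

```python
import os
from typing import Iterable, Mapping, Optional


def missing_env_vars(
    required: Iterable[str] = ("QDRANT_URL", "QDRANT_API_KEY"),
    env: Optional[Mapping[str, str]] = None,
) -> list[str]:
    """Return the names of required variables that are unset or empty."""
    env = os.environ if env is None else env
    return [name for name in required if not env.get(name)]


missing = missing_env_vars()
if missing:
    print(f"Set these in .env before creating the client: {', '.join(missing)}")
```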

In [None]:
bs = 512  # number of points sent per upsert request
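With the 100K-point dataset and a batch size of 512, the upload needs ceil(100000 / 512) = 196 requests, which matches the 196 iterations in the upload progress bar; the last batch holds 100000 - 195 × 512 = 160 points. The hypothetical `batches` generator below performs the same slicing as `points[i : i + bs]` in the upload loop.

```python
import math
from typing import Iterator, Sequence, TypeVar

T = TypeVar("T")


def batches(items: Sequence[T], batch_size: int) -> Iterator[Sequence[T]]:
    """Yield consecutive slices of at most batch_size items."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    for start in range(0, len(items), batch_size):
        yield items[start : start + batch_size]


n_points, batch_size = 100_000, 512
print(math.ceil(n_points / batch_size))  # 196
sizes = [len(b) for b in batches(range(n_points), batch_size)]
print(len(sizes), sizes[-1])  # 196 160
```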

In [None]:
dataset_combinations = [
    # {
    #     "model_name": "text-embedding-3-large",
    #     "dimensions": 3072,
    # },
    # {
    #     "model_name": "text-embedding-3-large",
    #     "dimensions": 1024,
    # },
    # {
    #     "model_name": "text-embedding-3-large",
    #     "dimensions": 1536,
    # },
    {
        "model_name": "text-embedding-3-small",
        "dimensions": 512,
    },
    # {
    #     "model_name": "text-embedding-3-small",
    #     "dimensions": 1024,
    # },
    # {
    #     "model_name": "text-embedding-3-small",
    #     "dimensions": 1536,
    # },
]
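Because BQ stores 1 bit per dimension instead of a 32-bit float, the quantized vectors kept in RAM are 32× smaller than the originals. The calculation below is plain arithmetic over the model/dimension pairs listed above at 100K points; it estimates raw vector storage only and is not a measurement of Qdrant's actual memory use, which also includes the HNSW graph, payloads and other overhead.

```python
N_POINTS = 100_000
FLOAT32_BYTES = 4

combinations = [
    ("text-embedding-3-large", 3072),
    ("text-embedding-3-large", 1536),
    ("text-embedding-3-large", 1024),
    ("text-embedding-3-small", 1536),
    ("text-embedding-3-small", 1024),
    ("text-embedding-3-small", 512),
]


def vector_bytes(dimensions: int) -> tuple[int, int]:
    """Return (float32 bytes, binary bytes) for one vector; binary packs 8 dims per byte."""
    return dimensions * FLOAT32_BYTES, -(-dimensions // 8)  # ceiling division


for model_name, dimensions in combinations:
    full, binary = vector_bytes(dimensions)
    print(
        f"{model_name}-{dimensions}: {full * N_POINTS / 1e6:.1f} MB float32, "
        f"{binary * N_POINTS / 1e6:.1f} MB binary"
    )
```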

In [10]:
for combination in dataset_combinations:
    MODEL_NAME, DIMENSIONS = combination["model_name"], combination["dimensions"]
    DATASET_NAME = f"Qdrant/dbpedia-entities-openai3-{MODEL_NAME}-{DIMENSIONS}-100K"
    collection_name = f"dbpedia-{MODEL_NAME}-{DIMENSIONS}"
    embedding_column_name = f"{MODEL_NAME}-{DIMENSIONS}-embedding"
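    # Skip collections that already exist instead of wiping them:
    # recreate_collection would delete an existing collection rather than raise
    existing_collections = {c.name for c in client.get_collections().collections}
    if collection_name in existing_collections:
        collection_info = client.get_collection(collection_name=collection_name)
        logger.warning(
            f"Collection {collection_name} already exists with "
            f"{collection_info.points_count} points. Skipping."
        )
        continue
    client.create_collection(
        collection_name=collection_name,
        vectors_config=models.VectorParams(
            size=DIMENSIONS,
            distance=models.Distance.COSINE,
        ),
        # Turn off indexing for faster upserts; it is re-enabled after the upload
        optimizers_config=models.OptimizersConfigDiff(indexing_threshold=0),
        # Keep the 1-bit quantized vectors in RAM for fast approximate search
        quantization_config=models.BinaryQuantization(
            binary=models.BinaryQuantizationConfig(always_ram=True),
        ),
        shard_number=2,
    )
    logger.info(f"Created collection {collection_name}")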
    try:
        dataset = load_dataset(
            DATASET_NAME,
            streaming=False,
            split="train",
        )
    except DatasetNotFoundError:
        logger.error(f"Dataset {DATASET_NAME} not found")
        continue
    logger.info(f"Loaded {DATASET_NAME} dataset")
    points = [
        {
            "id": i,
            "vector": embedding,
            "payload": {"text": data["text"], "title": data["title"]},
        }
        for i, (embedding, data) in enumerate(zip(dataset[embedding_column_name], dataset))
    ]
    points = [PointStruct(**point) for point in points]
    logger.info(f"Loaded {len(points)} points")
    
    collection_info = client.get_collection(collection_name=collection_name)
    if collection_info.points_count == 0:
        logger.info("Collection is empty. Begin upsert.")
        for i in tqdm(range(0, len(points), bs)):
            slice_points = points[i : i + bs]  # Create a slice of bs points
            client.upsert(
                collection_name=collection_name, points=slice_points, wait=True
            )
    # After the upsert, re-enable indexing so Qdrant builds the HNSW index for fast search
    # (indexing_threshold is in kilobytes; 20000 is Qdrant's default)
    client.update_collection(
        collection_name=collection_name,
        optimizers_config=models.OptimizersConfigDiff(indexing_threshold=20000),
    )
    logger.info(f"Collection {collection_name} is ready")

2024-02-06 12:17:09.274 | INFO     | __main__:<module>:26 - Created collection dbpedia-text-embedding-3-small-512
2024-02-06 12:17:15.988 | INFO     | __main__:<module>:36 - Loaded Qdrant/dbpedia-entities-openai3-text-embedding-3-small-512-100K dataset
2024-02-06 12:17:35.932 | INFO     | __main__:<module>:46 - Loaded 100000 points
2024-02-06 12:17:36.688 | INFO     | __main__:<module>:50 - Collection is empty. Begin upsert.
100%|██████████| 196/196 [03:19<00:00,  1.02s/it]
2024-02-06 12:20:56.677 | INFO     | __main__:<module>:61 - Collection dbpedia-text-embedding-3-small-512 is ready


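> 💡 Note on indexing: Indexing runs as a background process and does not affect exact search, which scans all vectors without using the HNSW index. Setting `indexing_threshold=0` during the upload skips index building and speeds up the upserts; raising the threshold afterwards lets Qdrant build the index used by approximate search.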
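Since the index is built in the background after `update_collection` raises the threshold, approximate-search numbers measured immediately afterwards may not reflect the fully indexed collection. One way to guard against this, sketched below as a hypothetical helper, is to poll `client.get_collection(...)` until its `status` is green. The sketch assumes the status compares equal to the string `"green"`, which as far as we know holds for qdrant-client's `CollectionStatus` string enum.

```python
import time
from typing import Any


def wait_for_green(
    client: Any, collection_name: str, timeout_s: float = 600.0, poll_s: float = 5.0
) -> bool:
    """Poll the collection until its status is green; return False on timeout.

    Hypothetical helper: `client` only needs a `get_collection(collection_name=...)`
    method that returns an object with a `status` attribute.
    """
    deadline = time.monotonic() + timeout_s
    while True:
        status = client.get_collection(collection_name=collection_name).status
        if status == "green":
            return True
        if time.monotonic() >= deadline:
            return False
        time.sleep(poll_s)
```

In the notebook it would be called as `wait_for_green(client, collection_name)` after the upload cell and before running the evaluation.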

## Create an evaluation split to compare exact and approximate (BQ) search

In [11]:
oversampling_range = np.arange(1.0, 3.1, 1.0)
rescore_range = [True, False]
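The stop value 3.1 in `np.arange` only makes 3.0 inclusive, so there are three oversampling factors. Combined with the two rescoring settings and the five result limits used in the evaluation loop, every query is evaluated under 3 × 2 × 5 = 30 configurations, each requiring one exact and one approximate search. The enumeration can be sketched with `itertools.product`:

```python
import itertools

import numpy as np

oversampling_range = np.arange(1.0, 3.1, 1.0)  # 1.0, 2.0, 3.0
rescore_range = [True, False]
limit_range = [100, 50, 20, 10, 5]  # the limits used in the evaluation loop

grid = list(itertools.product(oversampling_range.tolist(), rescore_range, limit_range))
n_queries = 100  # 0.1% test split of the 100K-point dataset

print(len(grid))  # 30
print(len(grid) * n_queries, 2 * len(grid) * n_queries)  # 3000 6000
```

That is 3000 result records and 6000 search requests for the 100 test queries.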

## Parameterized search

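We use `qdrant-client` to run a parameterized search and compare exact with approximate results. For the exact baseline, the `search` method is called with `SearchParams(exact=True)`, which skips the approximate HNSW index and scans all vectors. The approximate search uses the HNSW index together with the quantization parameters `rescore`, `oversampling` and `ignore` passed through `QuantizationSearchParams`.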
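To make `oversampling` and `rescore` concrete, the NumPy sketch below imitates the approximate path on toy data: it ranks all vectors by Hamming distance between binary codes, keeps the top `limit × oversampling` candidates, and, if `rescore` is on, re-ranks them by exact cosine similarity using the original vectors. This simplified model skips the HNSW graph entirely and is not how Qdrant executes a query; `exact_search` and `bq_search` are hypothetical names.

```python
import numpy as np


def normalize(x: np.ndarray) -> np.ndarray:
    """Scale rows to unit length so that the dot product equals cosine similarity."""
    return x / np.linalg.norm(x, axis=-1, keepdims=True)


def exact_search(query: np.ndarray, vectors: np.ndarray, limit: int) -> list[int]:
    """Brute force: rank every vector by cosine similarity to the query."""
    scores = normalize(vectors) @ normalize(query)
    return np.argsort(-scores, kind="stable")[:limit].tolist()


def bq_search(
    query: np.ndarray,
    vectors: np.ndarray,
    limit: int,
    oversampling: float = 1.0,
    rescore: bool = True,
) -> list[int]:
    """Toy BQ search: Hamming pre-ranking, oversampled candidate pool, optional rescoring."""
    codes = vectors > 0  # one boolean "bit" per dimension
    query_code = query > 0
    hamming = (codes != query_code).sum(axis=1)
    n_candidates = min(len(vectors), int(np.ceil(limit * oversampling)))
    candidates = np.argsort(hamming, kind="stable")[:n_candidates]
    if not rescore:
        return candidates[:limit].tolist()
    # Rescore only the candidate pool with the original float vectors
    scores = normalize(vectors[candidates]) @ normalize(query)
    return candidates[np.argsort(-scores, kind="stable")][:limit].tolist()


rng = np.random.default_rng(42)
vectors = rng.normal(size=(2_000, 64))
query = rng.normal(size=64)

truth = set(exact_search(query, vectors, 10))
for oversampling in (1.0, 2.0, 3.0):
    for rescore in (True, False):
        found = set(bq_search(query, vectors, 10, oversampling, rescore))
        print(oversampling, rescore, len(truth & found) / len(truth))
```

In this toy model, rescoring with `oversampling=1.0` only reorders the same candidates, so the overlap with the exact set cannot change; a larger oversampling factor widens the pool that rescoring can pick true neighbors from.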

In [12]:
def parameterized_search(
    point,
    oversampling: float,
    rescore: bool,
    exact: bool,
    collection_name: str,
    ignore: bool = False,
    limit: int = 10,
):
    if exact:
        return client.search(
            collection_name=collection_name,
            query_vector=point.vector,
            search_params=models.SearchParams(exact=exact),
            limit=limit,
        )
    else:
        return client.search(
            collection_name=collection_name,
            query_vector=point.vector,
            search_params=models.SearchParams(
                quantization=models.QuantizationSearchParams(
                    ignore=ignore,
                    rescore=rescore,
                    oversampling=oversampling,
                ),
                exact=exact,
            ),
            limit=limit,
        )
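The evaluation loop scores each approximate result list by the fraction of exact top-`limit` ids it also contains, which is recall@k with exact search as the ground truth, and writes one JSON record per line. A sketch of that metric, guarded against an empty exact result, plus a standard-library summary of mean accuracy per (oversampling, rescore, limit) could look like this; `recall_at_k` and `summarize` are hypothetical helpers.

```python
import json
from collections import defaultdict
from statistics import mean
from typing import Iterable, Optional


def recall_at_k(exact_ids: list, approx_ids: list) -> Optional[float]:
    """Fraction of exact (ground-truth) ids that the approximate search also returned."""
    if not exact_ids:
        return None  # nothing to compare against
    return len(set(exact_ids) & set(approx_ids)) / len(exact_ids)


def summarize(lines: Iterable[str]) -> dict:
    """Average the 'accuracy' field of JSON Lines records per (oversampling, rescore, limit)."""
    groups = defaultdict(list)
    for line in lines:
        if not line.strip():
            continue  # tolerate blank lines
        record = json.loads(line)
        key = (record["oversampling"], record["rescore"], record["limit"])
        groups[key].append(record["accuracy"])
    return {key: mean(values) for key, values in sorted(groups.items())}
```

For example, `summarize(open("results-text-embedding-3-small-512.json"))` would return one mean accuracy per configuration for the 512-dimension run.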


In [13]:
for combination in dataset_combinations:
    MODEL_NAME, DIMENSIONS = combination["model_name"], combination["dimensions"]
    DATASET_NAME = f"Qdrant/dbpedia-entities-openai3-{MODEL_NAME}-{DIMENSIONS}-100K"
    collection_name = f"dbpedia-{MODEL_NAME}-{DIMENSIONS}"
    embedding_column_name = f"{MODEL_NAME}-{DIMENSIONS}-embedding"
    dataset = load_dataset(
        DATASET_NAME,
        streaming=False,
        split="train",
    )
    ds = dataset.train_test_split(test_size=0.001, shuffle=True, seed=37)["test"]
    ds = ds.to_pandas().to_dict(orient="records")
    logger.info(f"Loaded {DATASET_NAME} dataset")
    # One JSON record per line (JSON Lines): one line per query and configuration
    with open(f"results-{MODEL_NAME}-{DIMENSIONS}.json", "w") as f:
        for element in tqdm(ds):
            # Wrap the query embedding in a PointStruct; the random id only labels the query
            point = PointStruct(
                id=random.randint(0, 100000),
                vector=element[embedding_column_name],
            )
            ## Running Grid Search
            for oversampling in oversampling_range:
                for rescore in rescore_range:
                    limit_range = [100, 50, 20, 10, 5]
                    for limit in limit_range:
                        try:
                            exact = parameterized_search(
                                point=point,
                                oversampling=oversampling,
                                rescore=rescore,
                                exact=True,
                                collection_name=collection_name,
                                limit=limit,
                            )
                            hnsw = parameterized_search(
                                point=point,
                                oversampling=oversampling,
                                rescore=rescore,
                                exact=False,
                                collection_name=collection_name,
                                limit=limit,
                            )
                        except Exception as e:
                            logger.warning(f"Skipping query {point.id}: {e}")
                            continue

                        exact_ids = [item.id for item in exact]
                        hnsw_ids = [item.id for item in hnsw]
                        # logger.info(f"Exact: {exact_ids}")
                        # logger.info(f"HNSW: {hnsw_ids}")

                        # Skip empty exact results to avoid dividing by zero
                        if not exact_ids:
                            continue

                        # Fraction of the exact top-`limit` ids also returned by the approximate search
                        accuracy = len(set(exact_ids) & set(hnsw_ids)) / len(exact_ids)

                        result = {
                            "query_id": point.id,
                            "oversampling": oversampling,
                            "rescore": rescore,
                            "limit": limit,
                            "accuracy": accuracy,
                        }
                        f.write(json.dumps(result))
                        f.write("\n")
                        logger.info(result)

2024-02-06 12:43:39.742 | INFO     | __main__:<module>:13 - Loaded Qdrant/dbpedia-entities-openai3-text-embedding-3-small-512-100K dataset
  0%|          | 0/100 [00:00<?, ?it/s]
2024-02-06 12:43:41.275 | INFO     | __main__:<module>:68 - {'query_id': 88849, 'oversampling': 1.0, 'rescore': True, 'limit': 100, 'accuracy': 0.44}
2024-02-06 12:43:41.805 | INFO     | __main__:<module>:68 - {'query_id': 88849, 'oversampling': 1.0, 'rescore': True, 'limit': 50, 'accuracy': 0.44}
2024-02-06 12:43:42.331 | INFO     | __main__:<module>:68 - {'query_id': 88849, 'oversampling': 1.0, 'rescore': True, 'limit': 20, 'accuracy': 0.35}
2024-02-06 12:43:42.858 | INFO     | __main__:<module>:68 - {'query_id': 88849, 'oversampling': 1.0, 'rescore': True, 