# qdrant

> Qdrant backend API

In [None]:
#| default_exp backends.qdrant

In [None]:
#| hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

In [None]:
#| export
from emb_opt.imports import *
from emb_opt.core import QueryResult, VectorDatabase, dataset_from_query_results

try:
    from qdrant_client import QdrantClient
    from qdrant_client.http import models
except ImportError:
    # qdrant_client is optional here: a missing install produces a warning
    # instead of an ImportError. Only ImportError is caught so that
    # KeyboardInterrupt/SystemExit and unrelated bugs are not silently hidden.
    warnings.warn('Failed to import Qdrant client - check package install')

  from .autonotebook import tqdm as notebook_tqdm


## Qdrant

`QdrantDatabase` integrates with a [Qdrant](https://qdrant.tech/) database. `search_request_kwargs` can contain any valid keyword arguments for a Qdrant [SearchRequest](https://qdrant.tech/documentation/search/#batch-search-api), for example a `filter` to restrict which points are returned.

In [None]:
#| export

class QdrantDatabase(VectorDatabase):
    'Qdrant backend'
    def __init__(self,
                 qdrant_client: 'QdrantClient', # qdrant client
                 collection_name: str, # qdrant collection name
                 k: int, # return `k` results per query
                 search_request_kwargs: Optional[dict]=None # kwargs for `SearchRequest`
                ):
        # String annotation above: `QdrantClient` is only defined when the optional
        # qdrant import succeeded, and a bare name would be evaluated at class creation.
        self.client = qdrant_client
        self.collection_name = collection_name
        self.k = k
        # Copy so later mutation of the caller's dict does not silently change this backend.
        self.search_request_kwargs = dict(search_request_kwargs) if search_request_kwargs else {}

    def _request_kwargs(self) -> dict:
        'Default `SearchRequest` kwargs, with any user-supplied `search_request_kwargs` taking precedence'
        # Merging (instead of passing both sets as keywords) lets users override
        # `limit`/`with_payload`/`with_vector` without a duplicate-keyword TypeError.
        return {'limit': self.k,
                'with_payload': True,
                'with_vector': True,
                **self.search_request_kwargs}

    def query(self, query_vectors: np.ndarray) -> Dataset:
        'Batch-search the collection with each row of `query_vectors` and return the hits as a `Dataset`'
        request_kwargs = self._request_kwargs()

        # One SearchRequest per query vector. Elements are converted to plain Python
        # floats because numpy scalars (e.g. float32) are not `float` instances.
        search_queries = [
            models.SearchRequest(vector=[float(x) for x in vec], **request_kwargs)
            for vec in query_vectors
        ]

        res = self.client.search_batch(
            collection_name=self.collection_name,
            requests=search_queries
        )

        # `search_batch` returns one list of points per request, in request order,
        # so the enumeration index identifies which query produced each hit.
        results = []
        for query_idx, result_batch in enumerate(res):
            for point in result_batch:
                results.append(QueryResult(query_idx,
                                           point.id,
                                           point.vector,
                                           point.score,
                                           point.payload))

        return dataset_from_query_results(results)

In [None]:
#|eval: false

client = QdrantClient(host="localhost", port=6444)

# Use the stored vectors of a few existing points as query vectors.
seed_points = client.retrieve(
    collection_name="zinc",
    ids=[0, 100, 1000],
    with_vectors=True
)
query_vecs = np.array([p.vector for p in seed_points])

# Unfiltered search: top-10 hits per query vector.
qdrant_db = QdrantDatabase(client, 'zinc', 10)
res = qdrant_db.query(query_vecs)

# Filtered search: only return points whose `preds` payload value is <= 6.0.
search_filter = models.Filter(
    must=[
        models.FieldCondition(
            key="preds",
            range=models.Range(lte=6.0),
        )
    ]
)

filtered_db = QdrantDatabase(client, 'zinc', 10, {'filter': search_filter})
filtered_res = filtered_db.query(query_vecs)

In [None]:
#| hide
# Regenerate the library module from this notebook's `#| export` cells.
import nbdev
nbdev.nbdev_export()