# Qdrant Plugins

> Functions and classes for using a Qdrant vector database as an `emb_opt` data source

In [None]:
#| default_exp plugins.qdrant

In [None]:
#| hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

In [None]:
#| export
from emb_opt.imports import *
from emb_opt.schemas import Query, Item, DataSourceResponse
from emb_opt.utils import build_batch_from_embeddings
from emb_opt.data_source import DataSourcePlugin, DataSourceModule

try:
    from qdrant_client import QdrantClient
    from qdrant_client.http import models
except:
    warnings.warn('Failed to import Qdrant client - check package install')

### Qdrant Data Plugin

`QdrantDataPlugin` integrates with a [Qdrant](https://qdrant.tech/) database. `search_request_kwargs` are passed as additional keyword arguments to each Qdrant [SearchRequest](https://qdrant.tech/documentation/search/#batch-search-api) (for example, a `filter`).

In [None]:
#| export

class QdrantDataPlugin(DataSourcePlugin):
    '''
    QdrantDataPlugin - data plugin for working with 
    a qdrant vector database.
    
    For each input `Query`, runs a `k` nearest neighbors search against 
    the qdrant collection `collection_name`. All queries are sent to the 
    server in a single `search_batch` call.
    
    Optionally, `item_key` denotes the key in an object's payload 
    corresponding to the item value. When set, that key is removed from 
    (a copy of) the payload and used as `Item.item`; the remaining payload 
    becomes `Item.data`.
    
    `search_request_kwargs` are optional extra kwargs sent to 
    `models.SearchRequest` (e.g. `filter`). The plugin always sets 
    `vector`, `limit`, `with_payload` and `with_vector` itself, because 
    building `Item`s requires the payload and the vector; those keys in 
    `search_request_kwargs` are overridden.
    '''
    def __init__(self,
                 k: int,                                     # k nearest neighbors to return
                 collection_name: str,                       # qdrant collection name
                 qdrant_client: QdrantClient,                # qdrant client
                 item_key: Optional[str]=None,               # key in qdrant payload denoting item value
                 search_request_kwargs: Optional[dict]=None  # optional kwargs for `SearchRequest`
                ):
        
        self.k = k
        self.collection_name = collection_name
        self.qdrant_client = qdrant_client
        self.item_key = item_key
        # copy so later changes to the caller's dict do not leak into this plugin
        self.search_request_kwargs = dict(search_request_kwargs or {})
        
    def __call__(self, inputs: List[Query]) -> List[DataSourceResponse]:
        '''
        Run one nearest neighbors search per query and return one 
        `DataSourceResponse` per query, in input order.
        '''
        # nothing to search for - skip the round trip to the server
        if not inputs:
            return []
        
        search_queries = [self._build_request(query) for query in inputs]
        
        res = self.qdrant_client.search_batch(collection_name=self.collection_name, 
                                              requests=search_queries)
        
        return [self._to_response(result_batch) for result_batch in res]
    
    def _build_request(self, query: Query) -> 'models.SearchRequest':
        '''Build the qdrant `SearchRequest` for a single query.'''
        # User kwargs come first so the fields this plugin depends on take 
        # precedence. Passing them as separate keywords next to 
        # `**search_request_kwargs` would raise a duplicate-keyword TypeError.
        request_kwargs = {
            **self.search_request_kwargs,
            'vector': query.embedding,
            'limit': self.k,
            'with_payload': True,
            'with_vector': True,
        }
        return models.SearchRequest(**request_kwargs)
    
    def _to_response(self, result_batch) -> DataSourceResponse:
        '''Convert the scored points returned for one query into a `DataSourceResponse`.'''
        items = []
        query_data = {'query_distance': []}
        for query_result in result_batch:
            # Copy the payload so popping `item_key` does not mutate the client's 
            # result object; `or {}` guards against points returned without a payload.
            payload = dict(query_result.payload or {})
            item_value = payload.pop(self.item_key) if self.item_key is not None else None
            items.append(Item(id=query_result.id,
                              item=item_value,
                              embedding=query_result.vector,
                              score=None,
                              data=payload
                             ))
            # qdrant's `score` is recorded as the query distance; whether it is a 
            # distance or a similarity depends on the collection's metric
            query_data['query_distance'].append(query_result.score)
            
        return DataSourceResponse(valid=bool(items), data=query_data, query_results=items)

In [None]:
#|eval: false

n_vectors = 1000
d_vectors = 128
n_queries = 5

# Random points, each with a uniform `rand` payload field (used by the filter 
# below) and a string `item` field (read back by the plugin via `item_key`)
vectors = np.random.randn(n_vectors, d_vectors)
payloads = [{'rand': np.random.rand(), 'item': str(np.random.randint(0, 1_000_000))}
            for _ in range(n_vectors)]
queries = np.random.randn(n_queries, d_vectors)

# requires a qdrant server listening on localhost:6333
client = QdrantClient(host="localhost", port=6333)

client.recreate_collection(
    collection_name="test_collection",
    vectors_config=models.VectorParams(size=d_vectors, distance=models.Distance.EUCLID),
)

operation_info = client.upsert(
    collection_name="test_collection",
    points=models.Batch(
        ids=list(range(n_vectors)),
        payloads=payloads,
        vectors=vectors.tolist(),
    ),
)

# lower the indexing threshold so this small test collection gets indexed
client.update_collection(
    collection_name='test_collection',
    optimizer_config=models.OptimizersConfigDiff(
        indexing_threshold=1
    ),
)

# only return points whose `rand` payload value is <= 0.8
search_filter = models.Filter(
    must=[
        models.FieldCondition(
            key="rand",
            range=models.Range(lte=0.8),
        )
    ]
)

data_function = QdrantDataPlugin(5, "test_collection", client, item_key='item', 
                                 search_request_kwargs={'filter': search_filter})

data_module = DataSourceModule(data_function)

batch = build_batch_from_embeddings(queries)
batch2 = data_module(batch)

# every returned result should be linked back to the query that produced it
for q in batch2:
    for r in q:
        assert r.internal.parent_id == q.id