In [15]:
from typing import List
import numpy as np
from faker import Faker
from qdrant_client import QdrantClient
from qdrant_client.http import models
from qdrant_client.http.models import Distance, VectorParams
from qdrant_client.http.models import PointStruct
from qdrant_client.http.models import Filter, FieldCondition, MatchValue
from qdrant_client.http.models import CollectionStatus
# docker run -p 6333:6333 -p 6334:6334 -v $(pwd)/qdrant_storage:/qdrant/storage:z dreg.sbmt.io/dhub/qdrant/qdrant

def create_embedding():
    return np.random.uniform(low=-1.0, high=1.0, size=100).tolist()

In [16]:
class QCollection:
    def __init__(self, 
                 collection_name: str,
                 url: str = "localhost", 
                 port: int = 6333):
        self.client = QdrantClient(url=url, port=port)
        self.collection_name = collection_name
        
    def get_collection_info(self):
        response = self.client.get_collection(collection_name=self.collection_name)
#         assert response.status == CollectionStatus.GREEN
#         assert response.vectors_count == 0
        return response
    
    def create_collection(self, emb_dim: int = 100):
        self.client.create_collection(
            collection_name=self.collection_name,
            vectors_config=VectorParams(size=emb_dim, distance=Distance.COSINE, on_disk=False),
        )
    
    def delete_collection(self):
        self.client.delete_collection(collection_name=self.collection_name)
        
    def payload_index_store_id(self):
        self.client.create_payload_index(
            collection_name=self.collection_name,
            field_name="store_id",
            field_schema="integer",
        )

In [17]:
class QDocument:
    def __init__(self, 
                 collection_name: str,
                 url: str = "localhost", 
                 port: int = 6333,):
        self.client = QdrantClient(url=url, port=port)
        self.collection_name = collection_name
        
    def add_document(self, doc_id: int, embedding: List[float], metadata: dict):
        # points with the same id will be overwritten when re-uploaded.
        self.client.upsert(
            collection_name=self.collection_name,
            points=[
                models.PointStruct(
                    id=doc_id,
                    payload=metadata,
                    vector=embedding,
                ),
            ],
        )
    
    def add_documents(self, doc_ids: List[int], embeddings: List[List[float]], metadata: List[dict]):
        self.client.upsert(
            collection_name=self.collection_name,
            points=models.Batch(
                ids=doc_ids,
                payloads=metadata,
                vectors=embeddings,
            ),
        )
    
    def delete_documents(self, doc_ids: List[int]):
        self.client.delete(
            collection_name=self.collection_name,
            points_selector=models.PointIdsList(
                points=doc_ids,
            ),
        )
        
    def retrieve_documents(self, doc_ids: List[int]) -> dict:
        response = self.client.retrieve(
            collection_name=self.collection_name,
            ids=doc_ids,
            with_payload=True,
            with_vectors=True,
        )
        return response
    
    def count_documents(self, store_id: int) -> dict:
        response = self.client.count(
            collection_name=self.collection_name,
            count_filter=models.Filter(
                must=[
                    models.FieldCondition(key="store_id", match=models.MatchValue(value=store_id)),
                ]
            ),
            exact=True,
        )
        return response
    
    def search(self, embedding: List[float], store_id: int) -> dict:
        response = self.client.search(
            collection_name=self.collection_name,
            query_filter=models.Filter(
                must=[
                    models.FieldCondition(
                        key="store_id",
                        match=models.MatchValue(
                            value=int(store_id),
                        ),
                    )
                ]
            ),
            search_params=models.SearchParams(hnsw_ef=128, exact=False),
            query_vector=embedding,
            limit=24,
            with_vectors=False,
            with_payload=True,
#             score_threshold=100,  # [0;1] cos
        )
        return response
    
    def batch_search(self):
        ...
        # https://qdrant.tech/documentation/concepts/search/#batch-search-api

In [27]:
q_collection = QCollection(collection_name="temp")
q_document = QDocument(collection_name="temp")

In [30]:
q_collection.delete_collection()
q_collection.create_collection()
# q_collection.payload_index_store_id()

In [31]:
print(" ".join([str(elem) + "\n" for elem in list(q_collection.get_collection_info())]))

('status', <CollectionStatus.GREEN: 'green'>)
 ('optimizer_status', <OptimizersStatusOneOf.OK: 'ok'>)
 ('vectors_count', 0)
 ('indexed_vectors_count', 0)
 ('points_count', 0)
 ('segments_count', 8)
 ('config', CollectionConfig(params=CollectionParams(vectors=VectorParams(size=100, distance=<Distance.COSINE: 'Cosine'>, hnsw_config=None, quantization_config=None, on_disk=False), shard_number=1, sharding_method=None, replication_factor=1, write_consistency_factor=1, read_fan_out_factor=None, on_disk_payload=True, sparse_vectors=None), hnsw_config=HnswConfig(m=16, ef_construct=100, full_scan_threshold=10000, max_indexing_threads=0, on_disk=False, payload_m=None), optimizer_config=OptimizersConfig(deleted_threshold=0.2, vacuum_min_vector_number=1000, default_segment_number=0, max_segment_size=None, memmap_threshold=None, indexing_threshold=20000, flush_interval_sec=5, max_optimization_threads=1), wal_config=WalConfig(wal_capacity_mb=32, wal_segments_ahead=0), quantization_config=None))
 ('p

In [32]:
fake_something = Faker()

embs = np.random.uniform(low=-1.0, high=1.0, size=(100000, 100))
doc_ids = list(range(len(embs)))
payloads = []

for i in range(len(embs)):
    payloads.append(
        {
            "store_id": [fake_something.random.randint(0, 100), fake_something.random.randint(0, 100), fake_something.random.randint(0, 100)],
            "product_name": " ".join(fake_something.words()),
            "product_sku": i,
        }
    )

In [33]:
%%time

step = 10000
for t in range(0, len(doc_ids), step):
    q_document.add_documents(doc_ids=doc_ids[t:step + t], 
                             embeddings=embs[t:step + t].tolist(), 
                             metadata=payloads[t:step + t])

CPU times: user 4.61 s, sys: 1.14 s, total: 5.75 s
Wall time: 10.9 s


In [23]:
q_document.retrieve_documents(doc_ids=[1,2,3])[0]

Record(id=3, payload={'product_name': 'management many see', 'product_sku': 3, 'store_id': [4, 39, 65]}, vector=[-0.134738, 0.040225297, 0.078741215, 0.0863309, -0.018019238, -0.10339125, 0.04067779, 0.14605172, -0.14569859, -0.02970379, 0.08513453, 0.039153688, -0.07417077, -0.022878684, -0.13430503, 0.13652149, -0.11096707, 0.12599096, 0.013955012, -0.010172852, -0.01825531, -0.14910343, 0.14927366, -0.15058796, 0.052306868, 0.102939606, -0.044486806, -0.08466696, 0.11035151, 0.046531666, 0.18274309, -0.00039389907, 0.11285905, -0.07643148, -0.05871996, 0.09378919, 0.052145507, 0.12202354, 0.067426324, 0.07736457, 0.17570704, -0.014540996, -0.18167174, -0.052250944, -0.13133423, -0.16808923, -0.04352003, 0.17681769, 0.08209532, 0.058933992, 0.14813001, -0.14794707, -0.076527335, 0.0579725, -0.13905379, 0.121065184, 0.13233972, 0.099033594, -0.041207895, 0.029846486, 0.1282525, 0.10591958, 0.020117737, -0.033288844, 0.056621358, 0.17621106, -0.09782136, -0.1660442, -0.10529423, 0.1787

In [38]:
q_collection.payload_index_store_id()

In [25]:
query_embedding = create_embedding()

In [41]:
%%time
q_document.search(embedding=query_embedding, store_id=10)

CPU times: user 1.86 ms, sys: 4.33 ms, total: 6.19 ms
Wall time: 17.5 ms


[ScoredPoint(id=1992, version=0, score=0.4075537, payload={'product_name': 'day stuff create', 'product_sku': 1992, 'store_id': [10, 2, 25]}, vector=None, shard_key=None),
 ScoredPoint(id=57303, version=5, score=0.33620995, payload={'product_name': 'without lot produce', 'product_sku': 57303, 'store_id': [10, 19, 99]}, vector=None, shard_key=None),
 ScoredPoint(id=39133, version=3, score=0.3061222, payload={'product_name': 'them report dinner', 'product_sku': 39133, 'store_id': [69, 10, 10]}, vector=None, shard_key=None),
 ScoredPoint(id=83610, version=8, score=0.30291712, payload={'product_name': 'part list argue', 'product_sku': 83610, 'store_id': [30, 10, 74]}, vector=None, shard_key=None),
 ScoredPoint(id=72308, version=7, score=0.29928777, payload={'product_name': 'agreement perform field', 'product_sku': 72308, 'store_id': [10, 96, 2]}, vector=None, shard_key=None),
 ScoredPoint(id=51038, version=5, score=0.27924663, payload={'product_name': 'reason reason final', 'product_sku': 

In [175]:
510 / 21, 510 / 8

(24.285714285714285, 63.75)

In [165]:
q_document.count_documents(19)

CountResult(count=9634)

In [167]:
q_document.client.count(
            collection_name=q_document.collection_name,
            exact=True,
        )

CountResult(count=1000000)