# chroma

> Chroma backend API

In [None]:
#| default_exp backends.chroma

In [None]:
#| hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

In [None]:
#| export
from emb_opt.imports import *
from emb_opt.core import QueryDataset, QueryResult, VectorDatabase

try:
    import chromadb
    from chromadb.api import Collection
except Exception:
    # Chroma is treated as an optional dependency: warn instead of failing the
    # import of this module. `Exception` (rather than a bare `except`) lets
    # KeyboardInterrupt / SystemExit propagate normally.
    warnings.warn('Failed to import Chroma - check package install')
    # Placeholder so the `Collection` annotation used later in this module
    # still resolves instead of raising NameError at class-definition time.
    Collection = None

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
#| export

class ChromaDatabase(VectorDatabase):
    '''
    `VectorDatabase` backend that answers batched vector queries using a
    Chroma collection.

    chroma_collection: collection object whose `query` method is called
    k: number of results requested per query vector (`n_results`)
    query_kwargs: optional extra keyword arguments forwarded to the
        collection's `query` call
    '''
    def __init__(self, 
                 chroma_collection: Collection,
                 k: int,
                 query_kwargs: Optional[dict]=None
                ):
        # Store the constructor argument itself; the previous code read a
        # module-level global named `collection`, which only exists when the
        # calling script happens to define one.
        self.collection = chroma_collection
        self.k = k
        self.query_kwargs = query_kwargs if query_kwargs else {}
    
    def query(self, query_vectors: np.ndarray) -> QueryDataset:
        '''
        Query the collection with every row of `query_vectors` and gather
        the hits into a single `QueryDataset`.

        query_vectors: array iterated row by row; each row is sent as one
            query embedding

        Returns the dataset built by `QueryDataset.from_query_results`, with
        one `QueryResult` per returned hit, tagged with the index of the
        query row that produced it.
        '''
        res = self.collection.query(query_embeddings=[list(i) for i in query_vectors],
                                    include=['documents', 'embeddings', 'metadatas', 'distances'],
                                    n_results = self.k,
                                    **self.query_kwargs
                                   )
        
        # Walk the hits that were actually returned for each query instead of
        # assuming exactly `k` of them, so a short result list does not raise
        # IndexError.
        per_query = zip(res['ids'], res['embeddings'], res['distances'],
                        res['documents'], res['metadatas'])
        
        results = []
        for query_idx, (hit_ids, hit_embs, hit_dists, hit_docs, hit_metas) in enumerate(per_query):
            for item_id, emb, dist, doc, meta in zip(hit_ids, hit_embs, hit_dists, hit_docs, hit_metas):
                results.append(QueryResult(query_idx,
                                           item_id,
                                           np.array(emb),
                                           dist,
                                           {'document' : doc,
                                            'metadata' : meta}
                                          ))
        
        return QueryDataset.from_query_results(results)

In [None]:
# Smoke test: fill a Chroma collection with random vectors and run a
# batched query through `ChromaDatabase`.
chroma_client = chromadb.Client()

def dummy_embedding(texts):
    "Placeholder embedding function: one random 256-d row per input text"
    return np.random.randn(len(texts), 256)

collection = chroma_client.create_collection(name="test", embedding_function=dummy_embedding)

# 64 records with string ids, matching document labels and random embeddings
ids = [str(idx) for idx in range(64)]
docs = [f'doc_{idx}' for idx in range(64)]
embs = [list(np.random.randn(256)) for _ in range(64)]

collection.add(embeddings=embs, documents=docs, ids=ids)

# 3 query vectors, 5 results requested for each
query_embs = np.random.randn(3, 256)
chroma_db = ChromaDatabase(collection, 5)
query_dataset = chroma_db.query(query_embs)

Using embedded DuckDB without persistence: data will be transient


In [None]:
#| hide
# Run nbdev's export step (presumably writes the `#| export` cells to the
# module named by `#| default_exp` — confirm against nbdev docs).
import nbdev; nbdev.nbdev_export()