In [None]:
import chromadb
DB_PATH = "./db"  # directory (relative to the working dir) backing the persistent Chroma store
client = chromadb.PersistentClient(path=DB_PATH)

In [None]:
import requests
from chromadb import Documents, Embeddings, EmbeddingFunction
from dotenv import load_dotenv

class MyEmbeddingFunction(EmbeddingFunction):
    """Chroma embedding function backed by a remote HTTP embedding endpoint.

    Each call POSTs ``{"inputs": [...documents...]}`` to ``url`` with a bearer
    token and returns the decoded JSON body as the embeddings. Presumably a
    Hugging Face inference endpoint (the key is read from ``HF_API_KEY``) —
    the response is expected to hold one vector per input document.
    """

    def __init__(self, url="", api_key="", timeout=120):
        """Resolve the endpoint settings and open a reusable HTTP session.

        Args:
            url: Endpoint URL. When empty/None, falls back to the
                ``EMBEDDED_ENDPOINT`` environment variable (``.env`` included).
            api_key: Bearer token. When empty/None, falls back to the
                ``HF_API_KEY`` environment variable (``.env`` included).
            timeout: Seconds ``requests`` waits for each request; ``None``
                waits forever.

        Raises:
            ValueError: if no URL or API key can be resolved.
        """
        import os

        # Load .env here so the env-var fallbacks below can see its values.
        # Arguments like os.getenv(...) at the call site are evaluated *before*
        # this runs, so without a fallback they may still be None on a fresh kernel.
        load_dotenv()
        self._url = url or os.getenv("EMBEDDED_ENDPOINT")
        self._api_key = api_key or os.getenv("HF_API_KEY")
        self._timeout = timeout

        if not self._url or not self._api_key:
            raise ValueError("URL and API_KEY must be specified")

        # One session for all calls so the HTTP connection is reused.
        self._session = requests.Session()

    def __call__(self, input: Documents) -> Embeddings:
        """Embed ``input`` documents with a single POST to the endpoint.

        Raises:
            requests.HTTPError: on a non-2xx response.
            requests.Timeout: if the endpoint does not answer within ``timeout``.
            ValueError: if the reply is not a list with one entry per document.
        """
        response = self._session.post(
            url=self._url,
            headers={
                "Authorization": f"Bearer {self._api_key}",
            },
            json={
                "inputs": input
            },
            # Without a timeout a stalled endpoint blocks the kernel indefinitely.
            timeout=self._timeout,
        )
        response.raise_for_status()
        embeddings = response.json()

        # Fail here with a clear message rather than somewhere inside Chroma.
        if not isinstance(embeddings, list) or len(embeddings) != len(input):
            got = len(embeddings) if isinstance(embeddings, list) else type(embeddings).__name__
            raise ValueError(
                f"Embedding endpoint returned {got!r} instead of {len(input)} embeddings"
            )
        return embeddings

In [None]:
import os

# Populate os.environ from .env *before* reading it: the os.getenv(...) arguments
# below are evaluated before MyEmbeddingFunction.__init__ (and its own
# load_dotenv call) runs, so on a fresh kernel they would otherwise be None.
load_dotenv()

hf_ef = MyEmbeddingFunction(
    api_key=os.getenv("HF_API_KEY"),
    url=os.getenv("EMBEDDED_ENDPOINT"),
)

collection = client.get_or_create_collection(name="chromadb_demo2_hf_model", embedding_function=hf_ef)

In [None]:
import csv

def get_csv_file(filename, encoding="utf-8"):
    """Read a CSV file and return its data rows, without the header row.

    Args:
        filename: Path of the CSV file.
        encoding: Text encoding of the file. Defaults to UTF-8 instead of the
            platform locale so non-ASCII text decodes identically on every OS.

    Returns:
        A list of rows, each a list of string fields. Completely blank lines
        are dropped, and an empty file yields ``[]``.
    """
    # newline="" is what the csv module requires: it handles line endings and
    # newlines inside quoted fields itself.
    with open(filename, "r", encoding=encoding, newline="") as f:
        reader = csv.reader(f)
        # Skip the header *record* (skipping the first physical line would break
        # on a quoted header field containing a newline). The default avoids
        # StopIteration on an empty file.
        next(reader, None)
        # csv.reader yields [] for blank lines (e.g. a trailing empty line);
        # such rows would break later row[0] / row[1] lookups.
        return [row for row in reader if row]

# Load the menu rows (header excluded) from the local CSV file.
MENU_CSV_PATH = "chinese_menu_items.csv"
data = get_csv_file(MENU_CSV_PATH)



In [None]:
# Add the data to the collection.
# add() runs the embedding function on every document it is given (a call to
# the remote endpoint) even when that ID is already stored in the persistent
# ./db, so only send IDs that are not stored yet. This keeps re-runs cheap and
# idempotent. Chroma rejects an empty add(), hence the guard.
all_ids = [row[0] for row in data]
stored_ids = set(collection.get(ids=all_ids, include=[])["ids"]) if all_ids else set()
new_rows = [row for row in data if row[0] not in stored_ids]

if new_rows:
    collection.add(
        ids=[row[0] for row in new_rows],
        documents=[row[1] for row in new_rows],
        metadatas=[{"cuisine": "chinese"} for _ in new_rows],
    )

In [None]:
chunk_size = 25  # documents per collection.add() call in the loop below


def chunked(seq, size):
    """Split ``seq`` into consecutive slices of at most ``size`` items.

    Raises:
        ValueError: if ``size`` is not a positive integer.
    """
    if size < 1:
        raise ValueError("size must be a positive integer")
    return [seq[i:i + size] for i in range(0, len(seq), size)]


# Flatten the data into two parallel lists
ids = [row[0] for row in data]
docs = [row[1] for row in data]

# Split both lists at the same boundaries so id_chunks[k] still lines up with doc_chunks[k]
id_chunks = chunked(ids, chunk_size)
doc_chunks = chunked(docs, chunk_size)

# Compact summary (number of chunks, size of each) instead of dumping every document
len(doc_chunks), [len(chunk) for chunk in doc_chunks]

In [None]:
# Add the data to the collection, one chunk (one embedding-function call) at a time.
# add() embeds everything it receives even when an ID is already stored in the
# persistent collection, so filter each chunk down to the IDs not stored yet
# and skip chunks that are already fully present.
for id_chunk, doc_chunk in zip(id_chunks, doc_chunks):
    stored = set(collection.get(ids=id_chunk, include=[])["ids"])
    new_pairs = [
        (doc_id, doc)
        for doc_id, doc in zip(id_chunk, doc_chunk)
        if doc_id not in stored
    ]
    if not new_pairs:
        continue  # nothing new in this chunk; no embedding request needed

    collection.add(
        ids=[doc_id for doc_id, _ in new_pairs],
        documents=[doc for _, doc in new_pairs],
        metadatas=[{"cuisine": "chinese"} for _ in new_pairs],
    )

In [None]:
# Nearest-neighbour search: the 10 documents tagged cuisine == "chinese"
# whose embeddings are closest to the query text "greasy".
query_params = {
    "query_texts": ["greasy"],
    "n_results": 10,
    "include": ["documents"],
    "where": {"cuisine": "chinese"},
}
result = collection.query(**query_params)

result["documents"]