In [None]:
import chromadb
# Chroma client whose storage path is the relative directory ./db
# (relative to the kernel's current working directory).
client = chromadb.PersistentClient(path="./db")

In [None]:
import os
import requests
from chromadb import Documents, Embeddings, EmbeddingFunction
from dotenv import load_dotenv

class MyEmbeddingFunction(EmbeddingFunction):
    """Chroma embedding function that delegates to a remote HTTP endpoint.

    Each call POSTs the documents as JSON to the configured endpoint with a
    bearer token and returns the decoded JSON response as the embeddings.
    """

    def __init__(self, api_key="", enpoint_url=""):
        """Resolve the credentials and create a reusable HTTP session.

        Args:
            api_key: Bearer token for the endpoint. When falsy, the ``API_KEY``
                environment variable is used instead.
            enpoint_url: URL of the embedding endpoint. When falsy, the
                ``ENDPOINT_URL`` environment variable is used instead. The
                misspelled name is kept because callers pass it by keyword.

        Raises:
            ValueError: If the API key or the endpoint URL is still missing
                after the environment fallback.
        """
        # Pick up variables from a .env file (if any) before the fallbacks below.
        load_dotenv()
        self._api_key = api_key or os.getenv('API_KEY')
        self._endpoint_url = enpoint_url or os.getenv('ENDPOINT_URL')
        if not self._api_key or not self._endpoint_url:
            raise ValueError("API key and endpoint URL must be provided.")
        # One session for the lifetime of the object so connections can be reused.
        self._session = requests.Session()

    def __call__(self, input: Documents) -> Embeddings:
        """Embed ``input`` by sending it to the endpoint.

        Args:
            input: The documents to embed; sent as the ``"inputs"`` field.

        Returns:
            The JSON-decoded response body.

        Raises:
            requests.HTTPError: If the endpoint answers with an error status.
        """
        # Deliberately NOT `with self._session as s:` -- leaving such a block
        # closes the session, which would throw away its pooled connections
        # after every single call.
        # NOTE(review): "wait_for_model" is sent as the string "true" inside
        # "parameters"; confirm the endpoint expects it in this form.
        response = self._session.post(
            self._endpoint_url,
            json={
                "inputs": input,
                "parameters": {
                    "wait_for_model": "true"
                }
            },
            headers={
                "authorization": f"Bearer {self._api_key}"
            }
        )
        response.raise_for_status()  # Raise an exception for HTTP errors
        return response.json()

In [None]:
import os

# Load .env *before* reading the variables: the os.getenv() arguments below are
# evaluated before MyEmbeddingFunction.__init__ gets a chance to call
# load_dotenv(), so on a fresh kernel they would otherwise both be None when the
# values live only in the .env file.
load_dotenv()

# Fetch the collection (creating it on first use) and attach the remote
# embedding function configured from HF_API_KEY / EMBEDDED_ENDPOINT.
collection = client.get_or_create_collection(
    name="chromadb_demo2_hf_model",
    embedding_function=MyEmbeddingFunction(
        api_key=os.getenv("HF_API_KEY"),
        enpoint_url=os.getenv("EMBEDDED_ENDPOINT"),
    ),
)

In [None]:
import csv

def get_csv_file(filename):
    """Read a CSV file and return its data rows without the header.

    Args:
        filename: Path of the CSV file to read.

    Returns:
        A list of rows, each row being a list of string fields. The first
        record (the header) is dropped; an empty file yields an empty list.
    """
    # newline="" lets the csv module do its own newline handling, which is
    # required for quoted fields that contain line breaks.
    with open(filename, "r", newline="") as f:
        reader = csv.reader(f)
        # Skip the header as a parsed CSV record; the None default keeps an
        # empty file from raising StopIteration.
        next(reader, None)
        return list(reader)

# Load the menu rows (header excluded) from chinese_menu_items.csv, a path
# relative to the working directory. Later cells read column 0 of each row as
# the id and column 1 as the document text.
data = get_csv_file("chinese_menu_items.csv")



In [None]:
# Insert every CSV row with a single add call: column 0 supplies the id,
# column 1 the document text, and every row gets the same cuisine metadata.
# NOTE(review): the chunked loop later in this notebook adds the same ids again.
collection.add(
    metadatas=[{"cuisine":"chinese"} for _ in data],
    documents=[row[1] for row in data],
    ids=[row[0] for row in data],
)

In [None]:
chunk_size = 25

# Column 0 of each row is the id, column 1 is the document text
ids = list(map(lambda row: row[0], data))
docs = list(map(lambda row: row[1], data))

# Slice both lists at the same offsets so chunk k of ids lines up with chunk k of docs
id_chunks = [ids[start:start + chunk_size] for start in range(0, len(ids), chunk_size)]
doc_chunks = [docs[start:start + chunk_size] for start in range(0, len(docs), chunk_size)]

In [None]:
# Insert the rows one chunk at a time: each add call receives one chunk of ids,
# the matching chunk of documents, and one metadata dict per id in the chunk.
for id_chunk, doc_chunk in zip(id_chunks, doc_chunks):
    chunk_metadatas = [{"cuisine":"chinese"} for _ in id_chunk]
    collection.add(ids=id_chunk, documents=doc_chunk, metadatas=chunk_metadatas)

In [None]:
# Ask for the 5 best matches for "greasy noodles" among entries whose metadata
# has cuisine == "chinese", returning only the document texts.
result = collection.query(
    where={"cuisine": "chinese"},
    include=["documents"],
    n_results=5,
    query_texts=["greasy noodles"],
)

# Last expression of the cell: display the matched documents
result["documents"]