In [1]:
import uuid
import polars as pl

from chromadb import HttpClient
from sentence_transformers import SentenceTransformer
from typing import List


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
train_df = pl.read_parquet('../data/train_yelp_review_embedded.parquet')
train_df.head()

label,text,embeddings
i64,str,"array[f32, 384]"
4,"""dr. goldberg offers everything…","[0.045555, -0.052686, … 0.002462]"
1,"""Unfortunately, the frustration…","[0.029399, -0.017978, … 0.011826]"
3,"""Been going to Dr. Goldberg for…","[0.038964, -0.037059, … 0.004994]"
3,"""Got a letter in the mail last …","[0.089701, -0.089196, … 0.030734]"
0,"""I don't know what Dr. Goldberg…","[-0.005041, 0.045442, … -0.028141]"


In [3]:
embeding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2', device="mps", cache_folder="../cache")

In [4]:
def create_embeddigns(entry):
    output = embeding_model.encode([entry['text']], device="mps")
    return output[0]

In [None]:
# texts = train_df.select('text').to_numpy().flatten()
# embeddings = embeding_model.encode(texts, device="mps")

In [None]:
# train_df = train_df.with_columns(
#     embeddings = pl.Series(name="embeddings", values=embeddings)
# )

In [None]:
# train_df.write_parquet("../data/train_yelp_review_embedded.parquet")

In [None]:
# train_df.head()

In [5]:
client = HttpClient(host='localhost', port=8000)
client.delete_collection("yelp_review")
collection = client.get_or_create_collection("yelp_review", configuration={
        "hnsw": {
            "space": "cosine",
            "ef_construction": 200
        }
    })

In [6]:
train_df = train_df.with_columns(
    pl.arange(0, train_df.height).map_elements(
        lambda _: str(uuid.uuid4()),
        return_dtype=pl.Utf8
    ).alias("ids")
)

In [7]:
ids = train_df.select('ids').to_numpy().flatten().tolist()
text = train_df.select('text').to_numpy().flatten().tolist()
stars = list(map(lambda x: {"star": x}, train_df.select('label').to_numpy().flatten().tolist()))
embeddings = train_df.select('embeddings').to_numpy().flatten().tolist()

In [8]:
total = len(ids)
print(f"Total rows to add: {total}")
BATCH_SIZE = 5000

for start in range(0, total, BATCH_SIZE):
    end = min(start + BATCH_SIZE, total)

    batch_ids = ids[start:end]
    batch_text = text[start:end]
    batch_stars = stars[start:end]
    batch_embeddings = embeddings[start:end]

    collection.add(
        ids=batch_ids,
        embeddings=batch_embeddings,
        metadatas=batch_stars,
        documents=batch_text,
    )
    print(f"Added rows {start}-{end-1} ({end-start} records)")

Total rows to add: 650000
Added rows 0-4999 (5000 records)
Added rows 5000-9999 (5000 records)
Added rows 10000-14999 (5000 records)
Added rows 15000-19999 (5000 records)
Added rows 20000-24999 (5000 records)
Added rows 25000-29999 (5000 records)
Added rows 30000-34999 (5000 records)
Added rows 35000-39999 (5000 records)
Added rows 40000-44999 (5000 records)
Added rows 45000-49999 (5000 records)
Added rows 50000-54999 (5000 records)
Added rows 55000-59999 (5000 records)
Added rows 60000-64999 (5000 records)
Added rows 65000-69999 (5000 records)
Added rows 70000-74999 (5000 records)
Added rows 75000-79999 (5000 records)
Added rows 80000-84999 (5000 records)
Added rows 85000-89999 (5000 records)
Added rows 90000-94999 (5000 records)
Added rows 95000-99999 (5000 records)
Added rows 100000-104999 (5000 records)
Added rows 105000-109999 (5000 records)
Added rows 110000-114999 (5000 records)
Added rows 115000-119999 (5000 records)
Added rows 120000-124999 (5000 records)
Added rows 125000-129

In [None]:
collection