In [None]:
from dotenv import load_dotenv
load_dotenv()

from os import getenv
from pathlib import Path
from random import random
import re
from glob import glob
from orjson import loads
from tqdm import tqdm
import numpy as np
import faiss
from supabase import create_client, Client

In [None]:
# Supabase client built from the SUPABASE_URL / SUPABASE_KEY environment
# variables (load_dotenv() has already been called in the imports cell).
supabase: Client = create_client(
	getenv("SUPABASE_URL"),
	getenv("SUPABASE_KEY"),
)

In [None]:
# Dimensionality of every vector stored in the index.
# NOTE(review): 3072 presumably matches the embedding model that produced the
# JSONL files — confirm against the embedding job.
EMBEDDING_DIM = 3072

# Flat index that stores each component with 8-bit scalar quantization; it
# must be trained before vectors can be added.
scalar_quantizer = faiss.IndexScalarQuantizer(EMBEDDING_DIM, faiss.ScalarQuantizer.QT_8bit)

# Wrap the quantized index so vectors can be added under caller-supplied
# 64-bit IDs (see `add_with_ids` further down).
index = faiss.IndexIDMap(scalar_quantizer)

In [None]:
# Gather the JSONL embedding files and sort them so they are always visited
# in the same order; the cell's output is the number of files found.
files = glob("/Volumes/Vault/OpenAI Embeddings/*.jsonl")
files.sort()
len(files)

In [None]:
# Download every row of the `documents` table page by page, then build a
# lookup from each row's `document` value to its `id`.
PAGE_SIZE = 1000

table = supabase.table("documents")

documents = []
offset = 0
while True:
	# `.range()` bounds are inclusive on both ends, so a page of PAGE_SIZE rows
	# ends at offset + PAGE_SIZE - 1. Ordering by `id` keeps page contents
	# stable between requests, so rows are neither repeated nor missed.
	response = (
		table.select("*")
		.order("id")
		.range(offset, offset + PAGE_SIZE - 1)
		.execute()
	)
	if not response.data:
		break

	documents.extend(response.data)
	# Advance by the number of rows actually received: if the server caps a
	# page below PAGE_SIZE, a fixed step would silently skip rows.
	offset += len(response.data)

# Maps `document` -> `id`; if two rows share a `document` value the later row wins.
documentIndexMap = {row["document"]: row["id"] for row in documents}

In [None]:
def train_index(sample_rate=0.1, seed=None):
	"""Train the global `index` on a random sample of the stored embeddings.

	Every line of every path in `files` is considered; each one is kept with
	probability `sample_rate`. Kept lines are parsed as JSON, and the vector at
	`response.body.data[0].embedding` is collected. Lines whose body has no
	`data` key are skipped (presumably failed requests — confirm against the
	batch output).

	Args:
		sample_rate: Fraction of lines to keep, in (0, 1]. The default of 0.1
			keeps roughly one line in ten.
		seed: Optional seed for the sampler. `None` draws a fresh random
			sample on every call; pass an integer for a reproducible one.

	Raises:
		ValueError: If `sample_rate` is outside (0, 1], or if no embedding was
			sampled (for example because `files` is empty).
	"""
	if not 0.0 < sample_rate <= 1.0:
		raise ValueError(f"sample_rate must be in (0, 1], got {sample_rate!r}")

	rng = np.random.default_rng(seed)
	embeddings = []

	for filename in tqdm(files):
		with open(filename, "r", encoding="utf-8") as f:
			for line in f:
				# Decide before parsing so skipped lines never pay the JSON decoding cost.
				if rng.random() >= sample_rate:
					continue

				data = loads(line)
				body = data["response"]["body"]
				if "data" not in body:
					continue

				embedding = np.array(body["data"][0]["embedding"], dtype=np.float32)
				embeddings.append(embedding)

	# An empty sample would reach faiss as a 1-D array and fail with an opaque
	# shape error, so report the real cause instead.
	if not embeddings:
		raise ValueError("no embeddings were sampled for training; check `files` and `sample_rate`")

	index.train(np.stack(embeddings))

train_index()

In [None]:
def process_dense_vectors(filename):
	"""Add every embedding found in one JSONL file to the global `index`.

	Each line is parsed as JSON. Lines whose `response.body` has no `data` key
	are skipped (presumably failed requests — confirm against the batch
	output). For the remaining lines, a trailing `_<digits>_<digits>` suffix is
	stripped from `custom_id` and the result is looked up in `documentIndexMap`
	to get the vector's ID, so several vectors may share one ID.

	Args:
		filename: Path of the JSONL file to read.

	Raises:
		KeyError: If a stripped `custom_id` is not a key of `documentIndexMap`.
	"""
	suffix_pattern = re.compile(r'_\d+_\d+$')
	identifiers = []
	embeddings = []

	with open(filename, "r", encoding="utf-8") as f:
		for line in f:
			data = loads(line)
			body = data["response"]["body"]
			if "data" not in body:
				continue

			identifier = documentIndexMap[suffix_pattern.sub('', data["custom_id"])]
			embedding = np.array(body["data"][0]["embedding"], dtype=np.float32)
			identifiers.append(identifier)
			embeddings.append(embedding)

	# A file with no usable rows would yield a shape-(0,) array, which
	# `add_with_ids` cannot handle; skip it rather than abort the whole build.
	if not embeddings:
		return

	identifiers = np.array(identifiers, dtype=np.int64)
	embeddings = np.array(embeddings, dtype=np.float32)

	index.add_with_ids(embeddings, identifiers)

# NOTE(review): re-running this cell adds every vector to `index` a second
# time; rebuild the index first when re-executing.
for filename in tqdm(files):
	process_dense_vectors(filename)

index.ntotal

In [None]:
# Persist the populated index. The target directory is created first because
# writing fails if it does not exist (e.g. on a fresh checkout).
index_path = Path("output/dense_index.faiss")
index_path.parent.mkdir(parents=True, exist_ok=True)
faiss.write_index(index, str(index_path))