In [None]:
from dotenv import load_dotenv
load_dotenv()

from os import getenv
from time import sleep
import sqlite3
from concurrent.futures import ThreadPoolExecutor
from threading import Lock
from tqdm import tqdm
import numpy as np
from scipy.sparse import csr_matrix, vstack
from transformers import AutoTokenizer
from supabase import create_client, Client

from utils.storage import list_sparse_vector_files, download_sparse_vectors

In [None]:
# Supabase connection. URL and key are read from the environment, which
# load_dotenv() in the first cell may have populated from a .env file.
supabase_url = getenv("SUPABASE_URL")
supabase_key = getenv("SUPABASE_KEY")
supabase: Client = create_client(supabase_url, supabase_key)

In [None]:
# Maximum number of highest-scoring, strictly positive terms written to
# `sparse_index` for each document.
TOP_K: int = 128

In [None]:
# Tokenizer of the SPLADE checkpoint. NOTE(review): presumably the model that
# produced the stored sparse vectors — confirm. Its vocabulary size is used as
# the width of the per-chunk sparse vectors built when indexing.
MODEL_NAME = "naver/splade-cocondenser-ensembledistil"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=MODEL_NAME)
tokenizer.vocab_size

In [None]:
# Sparse-vector files to index, as returned by utils.storage.list_sparse_vector_files.
# Each entry is registered as a `document` row in the next cell and later passed
# to download_sparse_vectors().
files = list_sparse_vector_files()
len(files)  # displayed as the cell output: number of files found

In [None]:
# Register every sparse-vector file as a row of `documents` (rows that already
# exist are left untouched), then read the whole table back to build a
# file name -> document id map used when writing `sparse_index`.
PAGE_SIZE = 1000

table = supabase.table("documents")

# Nothing to register when storage is empty; avoid sending an empty bulk upsert.
if files:
	table.upsert(
		[{"document": filename} for filename in files],
		ignore_duplicates=True,
		on_conflict="document",  # the client expects a comma-separated column string
	).execute()

documents = []
offset = 0
while True:
	# `range(start, end)` is inclusive on both ends, so request exactly PAGE_SIZE
	# rows. Ordering by id makes offset pagination stable: without an ORDER BY
	# the database may return rows in a different order per request, skipping
	# some documents (which would later surface as KeyErrors).
	response = (
		table.select("*")
		.order("id")
		.range(offset, offset + PAGE_SIZE - 1)
		.execute()
	)
	if not response.data:
		break

	documents.extend(response.data)
	# Advance by what was actually returned: the server may cap a page below PAGE_SIZE.
	offset += len(response.data)

documentIndexMap = {row["document"]: row["id"] for row in documents}

In [None]:
lock = Lock()

def process_sparse_vectors(filename):
	vectors = download_sparse_vectors(filename)
	chunks = []

	for idx, vector in vectors.items():
		indices = list(int(k) for k in vector.keys())
		values = list(float(v) for v in vector.values())
		data = csr_matrix((values, (np.zeros_like(values), indices)), shape=(1, tokenizer.vocab_size))
		chunks.append(data)

	if len(chunks) == 0:
		return

	vector = chunks[0] if len(chunks) == 1 else vstack(chunks).max(axis=0).tocsr()
	data = sorted(zip(vector.indices, vector.data), key=lambda x: x[1], reverse=True)[:TOP_K]

	insertionData = []
	for termIndex, score in data:
		if score > 0:
			insertionData.append({ "term": int(termIndex), "document_id": documentIndexMap[filename], "score": float(score) })

	while True:
		try:
			with lock:
				supabase.table("sparse_index").upsert(insertionData).execute()
			break
		except KeyboardInterrupt as e:
			raise e
		except Exception as e:
			print(e)
			sleep(5)

# Index every file on a thread pool. `executor.map` is lazy: draining it through
# tqdm drives the progress bar, and a worker's exception is re-raised here when
# its result is reached.
with ThreadPoolExecutor() as executor:
	pending = executor.map(process_sparse_vectors, files)
	_ = [outcome for outcome in tqdm(pending, total=len(files))]