In [None]:
from dotenv import load_dotenv
load_dotenv()

from os import getenv
from pathlib import Path
from time import sleep
import sqlite3
from concurrent.futures import ThreadPoolExecutor
from threading import Lock
from tqdm import tqdm
import numpy as np
from scipy.sparse import csr_matrix, vstack
from transformers import AutoTokenizer

from utils.storage import list_sparse_vector_files, download_sparse_vectors

In [None]:
# Maximum number of highest-scoring vocabulary terms stored per document.
TOP_K: int = 128

In [None]:
# SPLADE checkpoint; the notebook only needs its tokenizer's vocabulary size,
# which is used as the width of every sparse vector.
MODEL_NAME = "naver/splade-cocondenser-ensembledistil"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=MODEL_NAME)
tokenizer.vocab_size

In [None]:
DB_PATH = "./output/sparse_index.db"
# sqlite3 does not create missing parent directories ("unable to open database
# file"), so make sure ./output exists before connecting on a fresh checkout.
Path(DB_PATH).parent.mkdir(parents=True, exist_ok=True)

# check_same_thread=False allows worker threads to use this one connection;
# sqlite3 itself does not serialize them, so access must be guarded externally.
conn = sqlite3.connect(DB_PATH, check_same_thread=False)
cursor = conn.cursor()

cursor.execute('''
	CREATE TABLE IF NOT EXISTS documents (
		id INTEGER PRIMARY KEY AUTOINCREMENT,
		filename TEXT UNIQUE
	);
''')
# One row per (term, document) posting.
# NOTE(review): SQLite only enforces the FOREIGN KEY when `PRAGMA foreign_keys = ON`
# is set on the connection, which this notebook does not do.
cursor.execute('''
	CREATE TABLE IF NOT EXISTS inverted_index (
		term INTEGER,
		document_id INTEGER,
		score REAL,
		PRIMARY KEY (term, document_id),
		FOREIGN KEY (document_id) REFERENCES documents(id)
	);
''')
# NOTE(review): the composite primary key already provides an index whose leading
# column is `term`, so this index is likely redundant; kept for schema compatibility.
cursor.execute('''
	CREATE INDEX IF NOT EXISTS idx_term ON inverted_index (term);
''')

In [None]:
# Sparse-vector files available in storage; each filename is registered as one
# row of `documents` in the next cell.
files = list_sparse_vector_files()
len(files)

In [None]:
# Register every file as a document; INSERT OR IGNORE skips filenames that
# already exist (the column is UNIQUE), so re-running this cell is harmless.
cursor.executemany(
	"INSERT OR IGNORE INTO documents (filename) VALUES (?)",
	((name,) for name in files),
)
conn.commit()

# filename -> documents.id, covering every row currently in the table.
documents = cursor.execute("SELECT id, filename FROM documents").fetchall()
documentIndexMap = {filename: doc_id for doc_id, filename in documents}

In [None]:
# Serializes use of the shared sqlite3 connection/cursor across worker threads.
lock = Lock()

def _top_terms(vectors, vocab_size, top_k):
	"""Max-pool one file's sparse vectors and return its strongest terms.

	Args:
		vectors: mapping whose values are ``{term index: weight}`` dicts (keys are
			converted with ``int``, weights with ``float``); the outer keys are ignored.
		vocab_size: width of each sparse vector; scipy rejects term indices
			outside ``[0, vocab_size)``.
		top_k: maximum number of terms to keep.

	Returns:
		A list of ``(term, score)`` pairs (``int``, ``float``) with ``score > 0``,
		sorted by descending score and at most ``top_k`` long. Empty when
		``vectors`` is empty.
	"""
	chunks = []
	for vector in vectors.values():
		indices = [int(k) for k in vector.keys()]
		values = [float(v) for v in vector.values()]
		# Every entry sits in row 0; integer row indices avoid an implicit float->int cast.
		rows = np.zeros(len(indices), dtype=np.int64)
		chunks.append(csr_matrix((values, (rows, indices)), shape=(1, vocab_size)))

	if not chunks:
		return []

	# Element-wise maximum over all chunks yields one document-level vector.
	pooled = chunks[0] if len(chunks) == 1 else vstack(chunks).max(axis=0).tocsr()
	ranked = sorted(zip(pooled.indices, pooled.data), key=lambda pair: pair[1], reverse=True)[:top_k]
	return [(int(term), float(score)) for term, score in ranked if score > 0]

def process_sparse_vectors(filename):
	"""Compute and store the top-``TOP_K`` postings for one sparse-vector file.

	Reads the module globals ``download_sparse_vectors``, ``tokenizer``, ``TOP_K``,
	``documentIndexMap``, ``cursor``, ``conn`` and ``lock``.

	Rows are written with INSERT OR REPLACE inside a single transaction that is
	committed on success and rolled back on error. On error, no partial rows are
	left pending on the shared connection for another thread's commit to persist.
	Files without any positive-scoring term write nothing.

	NOTE(review): re-processing a file only overwrites the (term, document) pairs
	it writes; postings that drop out of the top-k on a re-run are not removed.

	Args:
		filename: storage name of the file; must be a key of ``documentIndexMap``
			whenever the file yields at least one posting (else ``KeyError``).
	"""
	postings = _top_terms(download_sparse_vectors(filename), tokenizer.vocab_size, TOP_K)
	if not postings:
		return

	document_id = documentIndexMap[filename]
	insertion_rows = [(term, document_id, score) for term, score in postings]

	with lock:
		# `with conn` commits on success and rolls back if executemany raises.
		with conn:
			cursor.executemany(
				"INSERT OR REPLACE INTO inverted_index (term, document_id, score) VALUES (?, ?, ?)",
				insertion_rows,
			)

with ThreadPoolExecutor() as executor:
	# Drain the lazy map through tqdm to show progress; iterating the results
	# re-raises any exception raised inside a worker.
	_ = [outcome for outcome in tqdm(executor.map(process_sparse_vectors, files), total=len(files))]