In [1]:
from dotenv import load_dotenv
load_dotenv()

from os import getenv
from json import dump
from collections import defaultdict
import sqlite3
import numpy as np
import torch
from scipy.sparse import csr_matrix
from transformers import AutoTokenizer, AutoModelForMaskedLM
from supabase import create_client, Client

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Supabase client used by the search RPC below. Both credentials are read from
# the environment, which load_dotenv() populated in the imports cell.
supabase = create_client(
    getenv("SUPABASE_URL"),
    getenv("SUPABASE_KEY"),
)

In [3]:
# Checkpoint shared by the tokenizer and the masked-LM weights.
MODEL_NAME = "naver/splade-cocondenser-ensembledistil"

# Load the weights in half precision with automatic device placement and switch
# to inference mode immediately (eval() returns the module itself).
model = AutoModelForMaskedLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    torch_dtype=torch.float16,
).eval()
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Display the device the model was placed on.
model.device

device(type='mps', index=0)

In [4]:
def search_index(query: str, k: int) -> dict:
	"""Encode `query` as a sparse SPLADE vector and search the Supabase index.

	The query is tokenized, passed through the masked-LM `model`, and turned
	into one weight per vocabulary entry via log(1 + relu(logits)) max-pooled
	over the sequence. The non-zero entries are sent to the
	"search_splade_index" RPC as (term, score) pairs.

	Args:
		query: Free-text search query.
		k: Value passed to the RPC as "p_top_k".

	Returns:
		A dict mapping each returned row's 'document' to its 'total_score'.
		Empty when the query activates no vocabulary term.

	Raises:
		ValueError: If the tokenized query is longer than 512 tokens.
	"""
	encoded = tokenizer(query, return_tensors='pt', padding=False, truncation=False)
	if encoded['input_ids'].shape[1] > 512:
		raise ValueError("Input text is too long")

	# Use distinct comprehension names so the `k` parameter is not shadowed.
	encoded = {name: tensor.to(model.device) for name, tensor in encoded.items()}

	with torch.no_grad():
		outputs = model(**encoded)

	# Mask out padded positions, then keep the strongest activation per vocab entry.
	activations = torch.log(1 + torch.relu(outputs.logits)) * encoded['attention_mask'].unsqueeze(-1)
	vector = torch.max(activations, dim=1)[0].squeeze()

	# flatten() always yields a 1-D tensor, so tolist() is a list even when
	# exactly one term is active (squeeze() would collapse it to a scalar).
	term_ids = vector.nonzero().flatten().cpu().tolist()
	if not term_ids:
		# Nothing to search for: return an empty result instead of raising.
		return {}

	weights = vector[term_ids].cpu().tolist()

	response = supabase.rpc("search_splade_index", {
		"p_query_terms": [
			{"term": int(term), "score": float(weight)}
			for term, weight in zip(term_ids, weights)
		],
		"p_top_k": k
	}).execute()

	return {row['document']: row['total_score'] for row in response.data}

In [5]:
# Example query: print each returned document id with its total score.
results = search_index("state of the art image segmentation", k=5)
for doc_id, score in results.items():
    print(f"{doc_id}: {score:.4f}")

LPeoJSBciKrMfVGnis9rvg: 20.8188
a84sFqjaRXr4N_hLI9vCWQ: 20.3556
XFFwdctZiWtBfwkRRL2l3A: 19.0164
NmRj10H0M6CFxiaySnKtkA: 18.7227
Bau9eDjWEeQLfK-5uVzo-Q: 18.6545
