In [5]:
import os

import psycopg
from pgvector.psycopg import register_vector
from sentence_transformers import SentenceTransformer


In [6]:
# --- Notebook configuration -------------------------------------------------
# Embedding model identifier passed to SentenceTransformer below.
embedding_model_name = "dunzhang/stella_en_1.5B_v5"

# PostgreSQL host and the benchmark whose name prefixes the vector database.
db_host = "cdas2"
benchmark_name = "snails"

In [7]:
# Load the embedding model straight onto the target GPU. Passing `device` to
# the constructor avoids first materialising the model on the library's
# default device and then moving it with a separate `.to(...)` call.
# NOTE(review): trust_remote_code=True lets the model repository run its own
# Python code at load time -- only keep it for repositories you trust.
# NOTE(review): "cuda:2" assumes a host with at least three visible GPUs.
embedding_model = SentenceTransformer(
    embedding_model_name,
    trust_remote_code=True,
    device="cuda:2",
)
# Cap the number of tokens the model will consume per input text.
embedding_model.max_seq_length = 8000
# Pad batches on the right-hand side of each sequence.
embedding_model.tokenizer.padding_side = "right"

In [8]:
# Open a connection to the benchmark's vector database.
# Settings are passed as keyword arguments rather than interpolated into a
# conninfo string, so values containing spaces or quotes cannot corrupt it.
db_conn = psycopg.connect(
    user="postgres",
    host=db_host,
    port=5432,
    # Prefer the standard PGPASSWORD environment variable so the credential
    # does not have to live in the notebook.
    # NOTE(review): the literal fallback only keeps the existing setup working;
    # remove it once the environment variable is set everywhere.
    password=os.environ.get("PGPASSWORD", "skalpel"),
    dbname=f"{benchmark_name}_skalpel_subsetter_vector_db",
)
# Register pgvector's adapters on this connection (needed to pass embeddings
# as query parameters in the cells below).
register_vector(db_conn)

In [9]:
# Cursor on the shared connection; the query cell below executes through it.
cursor = db_conn.cursor()

In [10]:
# Natural-language question to embed. The leading and trailing newlines are
# part of the text that gets encoded, so they are kept as-is.
question_text = """
For the vendor with vendor code V1010, show item names, consignment numbers, and last purchase price for items where the vendor is preferred.
"""
# Vector representation of the question, used as the search key in the
# similarity query below.
embedding = embedding_model.encode(question_text)


In [11]:
# Rank the table descriptions of one database by their `<=>` distance to the
# question embedding and keep only the closest fraction of them.
database = "SBODemoUS"
# Share of the database's tables to return (previously a bare 0.6 in the SQL).
keep_fraction = 0.6
query = """
SELECT table_name, description_embedding <=> (%s) as distance
FROM table_descriptions
WHERE database_name = %s
ORDER BY distance ASC
LIMIT (%s)::numeric * (SELECT COUNT(*) FROM table_descriptions WHERE database_name = %s)
"""
# Positional parameters, in placeholder order. The fraction is cast to numeric
# in SQL so the LIMIT arithmetic matches the original numeric literal.
params = [embedding, database, keep_fraction, database]
result = cursor.execute(query, params)
# fetchall() already returns a list of row tuples; no extra copy is needed.
result.fetchall()

[('ITM2', 0.3975302572826589),
 ('OVTP', 0.4027715338657396),
 ('AIT2', 0.40286251289564756),
 ('ASP1', 0.4089799833032042),
 ('CPN2', 0.425996223984642),
 ('OAMD', 0.4262357182060471),
 ('ASPP', 0.4280445382707232),
 ('UILM1', 0.4317667941246405),
 ('SPP1', 0.43493703481805945),
 ('OSPP', 0.43755489171733764),
 ('ILM1', 0.43888503029088544),
 ('GTM1', 0.43914265281522613),
 ('PQW1', 0.44011843788596416),
 ('DGP4', 0.44380357785832636),
 ('IOD1', 0.44455532021495303),
 ('IQI1', 0.44455532021495303),
 ('RSC5', 0.4461567036287134),
 ('APRC', 0.45074720058242046),
 ('CTR1', 0.4507885156622722),
 ('OTOB', 0.45141625131987473),
 ('AQI1', 0.45162050987211044),
 ('DLN8', 0.4520256832214853),
 ('OSOIL', 0.4526182200056904),
 ('CIN22', 0.4529393253493639),
 ('OPRC', 0.4533216277389498),
 ('SPP2', 0.45333938493455994),
 ('INV22', 0.4540010416548136),
 ('FCT1', 0.45417685988277223),
 ('ITM6', 0.4547971850994287),
 ('AMR1', 0.4564106139200179),
 ('MIN2', 0.4564184875436045),
 ('UILM3', 0.456425053