<a href="https://colab.research.google.com/github/jankovicsandras/plpgsql_bm25/blob/main/Postgres_hybrid_search_RRF.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Postgres hybrid search with Reciprocal Rank Fusion

https://github.com/jankovicsandras/plpgsql_bm25

https://github.com/pgvector/pgvector

## Installing PostgreSQL

In [None]:
# Add the official PostgreSQL (PGDG) apt repository and install PostgreSQL 16,
# including the server dev headers that the pgvector build below compiles against.
# `-y` answers apt's confirmation prompt so the cell also works under "Run All".
! sudo apt install -y gnupg2 wget nano
! sudo sh -c 'echo "deb http://apt.postgresql.org/pub/repos/apt $(lsb_release -cs)-pgdg main" > /etc/apt/sources.list.d/pgdg.list'
! curl -fsSL https://www.postgresql.org/media/keys/ACCC4CF8.asc | sudo gpg --dearmor -o /etc/apt/trusted.gpg.d/postgresql.gpg
! sudo apt update
! sudo apt install -y postgresql-16 postgresql-contrib-16 postgresql-server-dev-16
! service postgresql start
# msq() connects with "user=root", so create a matching superuser role
# (superuser is also needed for the server-side COPY FROM file used later).
! sudo -u postgres psql -c "CREATE USER root WITH SUPERUSER"

## Initialize the psycopg2 helpers

In [None]:
! pip install psycopg2

import psycopg2


def msq( t, verbose=False ) :
  """Run one SQL statement on the local "postgres" database and return its rows.

  A fresh connection (user "root") is opened for every call and is always
  closed again before returning.

  Args:
    t: the complete SQL text to execute (no parameters are bound).
    verbose: if True, print the column names and every returned row,
      with '|' between cells.

  Returns:
    The list of row tuples from ``cursor.fetchall()``. Returns ``[]`` when the
    statement produces no result set (CREATE/DROP/COPY, ...) or when it fails.
    Errors are printed rather than raised, so a failing statement does not
    stop the notebook.
  """
  res = []
  conn = psycopg2.connect("dbname=postgres user=root")
  try :
    # `with conn` only ends the transaction (commit, or rollback on an
    # uncaught exception); it does NOT close the connection, hence the
    # explicit close() in `finally`.
    with conn :
      with conn.cursor() as cur :
        try :
          cur.execute(t)
          # Statements without a result set leave `description` as None;
          # fetchall() on them would raise "no results to fetch" and print a
          # spurious error for every DDL / COPY statement.
          if cur.description is not None :
            res = cur.fetchall()
            if verbose :
              for cdesc in cur.description :
                print(cdesc[0],'|',end='')
              print('')
              for r in res:
                for c in r:
                  print(c,'|',end='')
                print('')
        except Exception as ex :
          print('!!!! msq() ERROR ',ex,'|',t,'|')
  finally :
    conn.close()
  return res


def iscorrectintopk(correct, res) :
  """Grade where the expected document text appears in a ranked result list.

  Double quotes are treated as single quotes on both sides, because the CSV
  export stores document texts with '"' replaced by "'".

  Args:
    correct: the expected document text.
    res: ranked rows (best first); every cell of a row is compared as str().

  Returns:
    2 if any cell of the first row matches, 1 if a later row matches,
    0 if nothing matches or ``res`` is empty.
  """
  def _normalized(value) :
    return str(value).replace('"', "'")

  target = correct.replace('"', "'")
  for rank, row in enumerate(res) :
    if any(_normalized(cell) == target for cell in row) :
      return 2 if rank == 0 else 1
  return 0


## Installing pgvector

In [None]:
# Build and install pgvector v0.8.0 from source (compiles against the
# postgresql-server-dev-16 headers installed earlier).
! git clone --branch v0.8.0 https://github.com/pgvector/pgvector.git
! cd pgvector && make
! cd pgvector && sudo make install
# IF NOT EXISTS keeps this step idempotent when the cell is re-run against the same database.
msq('CREATE EXTENSION IF NOT EXISTS vector;')

## Installing plpgsql_bm25 and plpgsql_bm25rrf

In [None]:
# Download the plpgsql_bm25 SQL script from https://github.com/jankovicsandras/plpgsql_bm25
# (-nc: skip the download when the file already exists, so re-running is cheap).
# NOTE(review): presumably defines bm25createindex() and bm25topk(), which are called later — confirm.
! wget -nc https://raw.githubusercontent.com/jankovicsandras/plpgsql_bm25/refs/heads/main/plpgsql_bm25.sql

# Execute the script in the "postgres" database (no host given, so libpq defaults apply) to create the functions.
# The /content/ path assumes Colab's working directory.
! psql postgresql://@/postgres -f /content/plpgsql_bm25.sql

# Download the RRF script from the same repository (https://github.com/jankovicsandras/plpgsql_bm25).
# NOTE(review): presumably defines bm25rrf(), the in-database hybrid search used later — confirm.
! wget -nc https://raw.githubusercontent.com/jankovicsandras/plpgsql_bm25/refs/heads/main/plpgsql_bm25rrf.sql

# Execute it in the same database (again assuming Colab's /content/ working directory).
! psql postgresql://@/postgres -f /content/plpgsql_bm25rrf.sql

In [None]:
# sentence-transformers provides the SentenceTransformer model that embeds documents and questions for pgvector.
! pip install sentence-transformers

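## Test datasets

InPars synthetic-query datasets in JSON Lines format: each line holds a document (`doc_text`) and a generated `question` whose answer is that document.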


In [None]:

# Download the InPars synthetic-query datasets (JSON Lines, one record per line).
# -nc skips files that are already present, so re-running the cell is cheap.
# The dataset names here must match `datasetnames` in the next cell.
! wget -nc https://zav-public.s3.amazonaws.com/inpars/msmarco_synthetic_queries_100k.jsonl
! wget -nc https://zav-public.s3.amazonaws.com/inpars/robust04_synthetic_queries_100k.jsonl
! wget -nc https://zav-public.s3.amazonaws.com/inpars/nq_synthetic_queries_100k.jsonl
! wget -nc https://zav-public.s3.amazonaws.com/inpars/trec_covid_synthetic_queries_100k.jsonl
! wget -nc https://zav-public.s3.amazonaws.com/inpars/fiqa_synthetic_queries_100k.jsonl
! wget -nc https://zav-public.s3.amazonaws.com/inpars/dbpedia_synthetic_queries_100k.jsonl
! wget -nc https://zav-public.s3.amazonaws.com/inpars/scidocs_synthetic_queries_100k.jsonl
! wget -nc https://zav-public.s3.amazonaws.com/inpars/scifacts_synthetic_queries_100k.jsonl
! wget -nc https://zav-public.s3.amazonaws.com/inpars/arguana_synthetic_queries_100k.jsonl
! wget -nc https://zav-public.s3.amazonaws.com/inpars/bioasq_synthetic_queries_100k.jsonl
! wget -nc https://zav-public.s3.amazonaws.com/inpars/climate_fever_synthetic_queries_100k.jsonl
! wget -nc https://zav-public.s3.amazonaws.com/inpars/cqadupstack_synthetic_queries_100k.jsonl
! wget -nc https://zav-public.s3.amazonaws.com/inpars/fever_synthetic_queries_100k.jsonl
! wget -nc https://zav-public.s3.amazonaws.com/inpars/hotpotqa_synthetic_queries_100k.jsonl
! wget -nc https://zav-public.s3.amazonaws.com/inpars/nfcorpus_synthetic_queries_100k.jsonl
! wget -nc https://zav-public.s3.amazonaws.com/inpars/quora_synthetic_queries_100k.jsonl
! wget -nc https://zav-public.s3.amazonaws.com/inpars/signal_synthetic_queries_100k.jsonl
! wget -nc https://zav-public.s3.amazonaws.com/inpars/touche_synthetic_queries_100k.jsonl
! wget -nc https://zav-public.s3.amazonaws.com/inpars/trec_news_synthetic_queries_100k.jsonl


In [None]:
import json, os, random
from sentence_transformers import SentenceTransformer

# ---- configuration ----
verbose = False
k = 99999        # LIMIT for the single-method BM25 / pgvector searches (effectively "all rows")
slimit = 99999   # search limit passed to bm25rrf()
samplenum = 200  # documents (and therefore questions) sampled per dataset
seed = 42        # fixes the document sample and question order so runs are reproducible
embeddermodel = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
# embedding size taken from the model itself (384 for all-MiniLM-L6-v2), so swapping
# the model cannot silently mismatch the vector(...) column type
vectordim = embeddermodel.get_sentence_embedding_dimension()
datasetnames = ['msmarco', 'robust04', 'nq', 'trec_covid', 'fiqa', 'dbpedia', 'scidocs', 'scifacts', 'arguana', 'bioasq', 'climate_fever', 'cqadupstack', 'fever', 'hotpotqa', 'nfcorpus', 'quora', 'signal', 'touche', 'trec_news']
fnpostfix = '_synthetic_queries_100k.jsonl'
columnname = 'doc_text'
idcolumnname = 'id'
embeddingcolumnname = 'embedding'
algo = 'plus'

# data and question containers, keyed by dataset name
alldata = {}
questions = {}
questionsolutions = {}
embeddedquestions = {}

random.seed(seed)

for dsname in datasetnames :
  print('\n---- ',dsname)
  # load all records (JSON Lines: one JSON object per line)
  with open( dsname+fnpostfix, encoding='utf-8' ) as f:
    alldata[dsname] = [json.loads(line) for line in f]
  # shuffle and keep a random sample of samplenum records
  random.shuffle(alldata[dsname])
  alldata[dsname] = alldata[dsname][:samplenum]

  # doc_text embeddings, stored next to each record
  embeddedcorpus = embeddermodel.encode( [ thisdoc[columnname] for thisdoc in alldata[dsname] ] )
  for thisdoc, embedding in zip(alldata[dsname], embeddedcorpus) :
    thisdoc['embedding'] = embedding

  # export to a ';'-separated CSV for Postgres COPY (ids start at 1)
  csvfilename = dsname+'_items.csv'
  with open(csvfilename, 'w', encoding='utf-8') as f:
    f.write( 'id;'+columnname+';'+embeddingcolumnname+'\n' )
    for i, thisdoc in enumerate(alldata[dsname], start=1) :
      # '"' is replaced by "'" so the quoted CSV field needs no escaping;
      # iscorrectintopk() applies the same replacement when comparing texts.
      doctext = thisdoc[columnname].replace('"','\'')
      f.write( str(i)+';"' + doctext + '";' + json.dumps( thisdoc['embedding'].tolist() ) + '\n' )
  print( csvfilename, 'is created' )

  # shuffle again so the question order is independent of the table order,
  # then record each question with its expected (solution) document text
  random.shuffle(alldata[dsname])
  questions[dsname] = [ thisdoc['question'] for thisdoc in alldata[dsname] ]
  questionsolutions[dsname] = [ thisdoc[columnname] for thisdoc in alldata[dsname] ]
  # question embeddings, in the same order as questions[dsname]
  embeddedquestions[dsname] = embeddermodel.encode( questions[dsname] )

  # Postgres: (re)create the items table and import the CSV
  tablename = dsname + '_items_table'
  msq('DROP TABLE IF EXISTS '+tablename+' CASCADE;')
  msq('CREATE TABLE '+tablename+' (id SERIAL, '+columnname+' TEXT, '+embeddingcolumnname+' vector('+str(vectordim)+'));')
  # server-side COPY reads the file as the PostgreSQL server process, so give it an absolute path
  msq('COPY '+tablename+' FROM \''+os.path.abspath(csvfilename)+'\' DELIMITER \';\' CSV HEADER;')
  # fail fast: a failed import would make every later search result meaningless
  loaded = msq('SELECT COUNT(*) FROM '+tablename+';')
  assert loaded and loaded[0][0] == len(alldata[dsname]), ('COPY into '+tablename+' failed', loaded)


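## Running the questions

Each question is answered four ways: BM25 (plpgsql_bm25), vector search (pgvector), a Python Reciprocal Rank Fusion of those two result lists, and the in-database `bm25rrf` function. A method scores 2 if the solution document is its top hit, 1 if it is elsewhere in its top 5, and 0 otherwise.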


In [None]:
import json, time, math

# ---- evaluation settings ----
topn = 5   # a result list counts as correct only if the solution is within its first topn rows
rrfk = 60  # RRF constant: each method contributes 1 / (rrfk + rank), rank starting at 1

# per-method accumulators; 'total' sums over all datasets
stats = {
  'total': {
    'questions':0,
    'plpgsql_bm25_speed total':0,
    'plpgsql_bm25_correct total':0,
    'pgvector_speed total':0,
    'pgvector_correct total':0,
    'python_rrf_speed total':0,
    'python_rrf_correct total':0,
    'bm25rrf_speed total':0,
    'bm25rrf_correct total':0
  },
  # datasets where the Python RRF and SQL bm25rrf() correctness totals disagree
  'diffs':[]
}
stats2 = {}

for dsname in datasetnames :
  print('\n----',dsname,'----------------------------------------------------\n')
  tablename = dsname + '_items_table'

  stats[dsname] = {
    'plpgsql_bm25_speed':[],
    'plpgsql_bm25_correct':[],
    'pgvector_speed':[],
    'pgvector_correct':[],
    'python_rrf_speed':[],
    'python_rrf_correct':[],
    'bm25rrf_speed':[],
    'bm25rrf_correct':[]
  }

  # build the plpgsql_bm25 index on the text column (tokenization is done inside Postgres)
  msq( 'SELECT bm25createindex( \''+tablename+'\', \''+columnname+'\', \''+algo+'\' );', verbose )

  for qi,q in enumerate(questions[dsname]) :
    solution = questionsolutions[dsname][qi]
    if verbose :
      print('\n----Question',qi,':',q)
      print('Solution:',solution)
    stats['total']['questions'] += 1

    # SQL-escaped question text and the question embedding as a pgvector literal,
    # built once here so the timed sections below measure only the searches
    qsql = q.replace("'","''")
    qvec = json.dumps( embeddedquestions[dsname][qi].tolist() )

    # plpgsql_bm25 BM25 search (the full ranked list is kept for the RRF below)
    if verbose : print('---- plpgsql_bm25')
    starttime = time.perf_counter()
    mres = msq( 'SELECT * FROM bm25topk( \''+tablename+'\', \''+columnname+'\', \''+qsql+'\', '+str(k)+', \''+algo +'\' );', verbose )
    stats[dsname]['plpgsql_bm25_speed'].append( time.perf_counter()-starttime )

    # pgvector search: rows are (id, doc_text, embedding), nearest by L2 distance (<->) first
    if verbose : print('---- pgvector')
    starttime = time.perf_counter()
    mres2 = msq( 'SELECT * FROM '+tablename+' ORDER BY '+embeddingcolumnname+' <-> \''+qvec+'\' LIMIT '+str(k)+';', verbose )
    stats[dsname]['pgvector_speed'].append( time.perf_counter()-starttime )

    # Reciprocal Rank Fusion https://medium.com/@devalshah1619/mathematical-intuition-behind-reciprocal-rank-fusion-rrf-explained-in-2-mins-002df0cc5e2a
    # rrfscore = 1 / (rrfk + rank_vector) + 1 / (rrfk + rank_bm25), ranks starting at 1;
    # documents are matched across the two lists by their text.
    starttime = time.perf_counter()
    # doc text -> 0-based BM25 rank; setdefault keeps the first occurrence, like list.index(),
    # but makes each lookup O(1) instead of a linear scan per vector result
    # NOTE(review): assumes the 3rd column of bm25topk() rows is the document text — confirm
    bm25rank = {}
    for rank, row in enumerate(mres) :
      bm25rank.setdefault(row[2], rank)
    mres3 = []
    for ri, r in enumerate(mres2) :
      rrfscore = 1 / ( rrfk + ri + 1 )
      if r[1] in bm25rank :
        rrfscore += 1 / ( rrfk + bm25rank[r[1]] + 1 )
      mres3.append( [ r[0], rrfscore, r[1] ] )
    mres3.sort(key=lambda x : x[1], reverse=True)
    stats[dsname]['python_rrf_speed'].append( time.perf_counter()-starttime )
    mres3 = mres3[:topn]
    stats[dsname]['python_rrf_correct'].append( iscorrectintopk( solution, mres3 ) )
    if verbose :
      print('---- Python RRF')
      print('id | rrfscore | doctext')
      for r in mres3 :
        print(r[0],r[1],r[2])

    # bm25rrf: the same hybrid search done inside Postgres
    if verbose : print('---- bm25rrf')
    starttime = time.perf_counter()
    mres4 = msq( 'SELECT * FROM bm25rrf(\''+qsql+'\', \''+qvec+'\', \''+tablename+'\', \''+idcolumnname+'\', \''+columnname+'\', \''+embeddingcolumnname+'\', \''+str(slimit)+'\', \''+algo +'\');', verbose )
    stats[dsname]['bm25rrf_speed'].append( time.perf_counter()-starttime )
    mres4 = mres4[:topn]
    stats[dsname]['bm25rrf_correct'].append( iscorrectintopk( solution, mres4 ) )

    # single-method correctness, graded only now because the full lists were needed for the RRF above
    stats[dsname]['plpgsql_bm25_correct'].append( iscorrectintopk( solution, mres[:topn] ) )
    stats[dsname]['pgvector_correct'].append( iscorrectintopk( solution, mres2[:topn] ) )

  # per-dataset sums: printed, added to the grand totals, and compared between the two RRF variants
  stats2[dsname] = {}
  for statname in stats[dsname] :
    stats2[dsname][statname+' total'] = [ sum(stats[dsname][statname]) ]
  for statname in stats2[dsname] :
    print( statname, stats2[dsname][statname][0] )
    stats['total'][statname] += stats2[dsname][statname][0]
  if stats2[dsname]['python_rrf_correct total'][0] != stats2[dsname]['bm25rrf_correct total'][0] :
    stats['diffs'].append( {'dsname':dsname, 'python_rrf_correct total': stats2[dsname]['python_rrf_correct total'][0], 'bm25rrf_correct total': stats2[dsname]['bm25rrf_correct total'][0]} )

# print results
print('\n-------- Total stats')
# loop variable is deliberately not `k`: reusing that name would overwrite the search
# LIMIT configured in the previous cell and break this cell on a second run
for statkey in stats['total'] :
  print( statkey, ':', stats['total'][statkey] )
print('\n-------- diffs')
print( stats['diffs'] )
