BM25 + Neural Rerankers

In [1]:
!pip install -U beir
!pip install -U pandas
!pip install 'elasticsearch<7.14.0'



In [2]:
import pathlib, os
import time
import urllib.request

import pandas as pd
from beir import util
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.search.lexical import BM25Search as BM25
from beir.retrieval.evaluation import EvaluateRetrieval
from beir.reranking.models import CrossEncoder
from beir.reranking import Rerank

  from tqdm.autonotebook import tqdm


Elasticsearch

In [3]:
# Optional cleanup step, intentionally disabled: the shell command is wrapped in
# a string literal, so this cell only evaluates (and echoes) the string.
# Remove the surrounding triple quotes to actually delete the extracted
# elasticsearch-7.9.2 directory.
"""
!rm -rf elasticsearch-7.9.2
"""

'\n!rm -rf elasticsearch-7.9.2\n'

In [4]:
if not os.path.exists('elasticsearch-oss-7.9.2-linux-x86_64.tar.gz'):
  !wget https://artifacts.elastic.co/downloads/elasticsearch/elasticsearch-oss-7.9.2-linux-x86_64.tar.gz
  !wget https://artifacts.elastic.co/downloads/elasticsearch/elasticsearch-oss-7.9.2-linux-x86_64.tar.gz.sha512
  !tar -xzf elasticsearch-oss-7.9.2-linux-x86_64.tar.gz
  !sudo chown -R daemon:daemon elasticsearch-7.9.2/
  !shasum -a 512 -c elasticsearch-oss-7.9.2-linux-x86_64.tar.gz.sha512 

In [5]:
%%bash --bg

# Start the Elasticsearch server from the extracted elasticsearch-7.9.2
# directory, running it as the `daemon` user. The --bg flag on the cell magic
# runs this script in the background so the notebook can continue while the
# server stays up.
sudo -H -u daemon elasticsearch-7.9.2/bin/elasticsearch

Starting job # 0 in a separate thread.


In [6]:
# Wait until Elasticsearch actually answers on its HTTP port instead of
# sleeping for a fixed 20 seconds: a fixed delay is too short on a slow machine
# and wasted time on a fast one.
ES_READY_URL = 'http://localhost:9200/'
ES_READY_TIMEOUT_S = 120  # give up after this many seconds
ES_READY_POLL_S = 1       # pause between connection attempts

es_deadline = time.monotonic() + ES_READY_TIMEOUT_S
while True:
  try:
    with urllib.request.urlopen(ES_READY_URL, timeout=5) as es_response:
      if es_response.status == 200:
        break
  except OSError:
    # Connection refused/reset or an HTTP error status: the server is still
    # starting up, so keep polling until the deadline.
    pass
  if time.monotonic() >= es_deadline:
    raise RuntimeError(
        'Elasticsearch did not respond at {} within {} seconds'.format(
            ES_READY_URL, ES_READY_TIMEOUT_S))
  time.sleep(ES_READY_POLL_S)

In [7]:
%%bash

# Sanity check: list the running processes whose command line mentions
# "elasticsearch" to confirm the background server was launched.
ps -ef | grep elasticsearch

root         801       1  0 08:52 ?        00:00:00 sudo -H -u daemon elasticsearch-7.9.2/bin/elasticsearch
daemon       802     801  4 08:52 ?        00:00:24 /content/elasticsearch-7.9.2/jdk/bin/java -Xshare:auto -Des.networkaddress.cache.ttl=60 -Des.networkaddress.cache.negative.ttl=10 -XX:+AlwaysPreTouch -Xss1m -Djava.awt.headless=true -Dfile.encoding=UTF-8 -Djna.nosys=true -XX:-OmitStackTraceInFastThrow -XX:+ShowCodeDetailsInExceptionMessages -Dio.netty.noUnsafe=true -Dio.netty.noKeySetOptimization=true -Dio.netty.recycler.maxCapacityPerThread=0 -Dio.netty.allocator.numDirectArenas=0 -Dlog4j.shutdownHookEnabled=false -Dlog4j2.disable.jmx=true -Djava.locale.providers=SPI,COMPAT -Xms1g -Xmx1g -XX:+UseG1GC -XX:G1ReservePercent=25 -XX:InitiatingHeapOccupancyPercent=30 -Djava.io.tmpdir=/tmp/elasticsearch-4541497722120165995 -XX:+HeapDumpOnOutOfMemoryError -XX:HeapDumpPath=data -XX:ErrorFile=logs/hs_err_pid%p.log -Xlog:gc*,gc+age=trace,safepoint:file=logs/gc.log:utctime,pid,tags:filecou

In [8]:
%%bash

# Request the root endpoint on port 9200 (-s hides curl's progress output).
# A JSON reply with the node name and version shows the server is answering
# HTTP requests.
curl -sX GET "localhost:9200/"

{
  "name" : "8d69a4dda403",
  "cluster_name" : "elasticsearch",
  "cluster_uuid" : "TmDaTiYKRPubdoUYSyCXQA",
  "version" : {
    "number" : "7.9.2",
    "build_flavor" : "oss",
    "build_type" : "tar",
    "build_hash" : "d34da0ea4a966c4e49417f2da2f244e3e97b4e6e",
    "build_date" : "2020-09-23T00:45:33.626720Z",
    "build_snapshot" : false,
    "lucene_version" : "8.6.2",
    "minimum_wire_compatibility_version" : "6.8.0",
    "minimum_index_compatibility_version" : "6.0.0-beta1"
  },
  "tagline" : "You Know, for Search"
}


In [9]:
def eval_metrics(model_name, ndcg, _map, recall, precision, k=10):
  """Collect the MAP, NDCG, precision and recall scores at cutoff ``k`` in one row.

  Args:
    model_name: Label used as the index of the single result row.
    ndcg: Mapping that contains the key ``'NDCG@<k>'``.
    _map: Mapping that contains the key ``'MAP@<k>'``.
    recall: Mapping that contains the key ``'Recall@<k>'``.
    precision: Mapping that contains the key ``'P@<k>'``.
    k: Rank cutoff whose scores are reported. Defaults to 10, which yields the
      columns ``MAP@10``, ``NDCG@10``, ``P@10`` and ``Recall@10``.

  Returns:
    A one-row ``pandas.DataFrame`` indexed by ``model_name`` with the columns
    ``MAP@<k>``, ``NDCG@<k>``, ``P@<k>`` and ``Recall@<k>``, in that order.

  Raises:
    KeyError: If one of the mappings has no entry for the requested cutoff.
  """
  # Insertion order of this dict fixes the column order of the result.
  scores_by_metric = {'MAP': _map, 'NDCG': ndcg, 'P': precision, 'Recall': recall}
  row = {}
  for metric, scores in scores_by_metric.items():
    column = '{}@{}'.format(metric, k)
    row[column] = scores[column]
  return pd.DataFrame([row], index=[model_name])

In [10]:
# Elasticsearch host and the index that will hold the corpus.
hostname = 'localhost' 
index_name = 'scifact'
# BEIR dataset name; it is substituted into the download URL below.
dataset = 'scifact'
url = 'https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip'.format(dataset)
# The archive is downloaded and unpacked under ./datasets in the current working directory.
out_dir = os.path.join(os.getcwd(), 'datasets')
data_path = util.download_and_unzip(url, out_dir)
# Load the corpus, the queries and the relevance judgements for the chosen split.
corpus, queries, qrels = GenericDataLoader(data_path).load(split='test') # split is one of 'test', 'train', 'dev'
# BM25 search backed by the Elasticsearch server on `hostname`.
model_bm25 = BM25(index_name=index_name, hostname=hostname, initialize=True) # initialize=True : reindex
retriever = EvaluateRetrieval(model_bm25)
# First-stage retrieval; the results are evaluated and reranked in later cells.
results_bm25 = retriever.retrieve(corpus, queries)

  0%|          | 0/5183 [00:00<?, ?it/s]

  0%|          | 0/5183 [00:00<?, ?docs/s]
que: 100%|██████████| 3/3 [00:19<00:00,  6.45s/it]


In [11]:
# Score the BM25 results against the relevance judgements at the retriever's
# configured cutoffs, then print the @10 metrics as a one-row table.
ndcg, _map, recall, precision = retriever.evaluate(qrels, results_bm25, retriever.k_values)
print(eval_metrics('BM25', ndcg, _map, recall, precision))

       MAP@10  NDCG@10     P@10  Recall@10
BM25  0.64383  0.68998  0.09033    0.81644


Reranking

In [12]:
# Load the cross-encoder checkpoint used to rescore the BM25 candidates.
cross_encoder_model = CrossEncoder('cross-encoder/ms-marco-electra-base')
reranker = Rerank(cross_encoder_model, batch_size=128)
# NOTE(review): top_k=20 presumably limits reranking to the 20 best BM25 hits
# per query - confirm against beir.reranking.Rerank.
rerank_results = reranker.rerank(corpus, queries, results_bm25, top_k=20)

Downloading:   0%|          | 0.00/730 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/438M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/316 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/112 [00:00<?, ?B/s]

Batches:   0%|          | 0/47 [00:00<?, ?it/s]

In [13]:
# Score the reranked results with the same judgements and cutoffs as the BM25
# run (evaluate is invoked on the class here; the cutoffs still come from
# `retriever`), then print the @10 metrics as a one-row table.
ndcg_rr, _map_rr, recall_rr, precision_rr = EvaluateRetrieval.evaluate(qrels, rerank_results, retriever.k_values)
print(eval_metrics('Electra', ndcg_rr, _map_rr, recall_rr, precision_rr))

          MAP@10  NDCG@10   P@10  Recall@10
Electra  0.64814  0.69154  0.091     0.8105
