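DENSE RETRIEVAL

Evaluate dense (embedding-based) retrieval models on the test split of the BEIR SciFact dataset and report MAP, NDCG, precision and recall at cutoffs 10, 100 and 1000.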

In [1]:
!pip install -U beir

Collecting beir
  Downloading beir-0.2.2-py3-none-any.whl (49 kB)
[K     |████████████████████████████████| 49 kB 2.1 MB/s 
[?25hCollecting tensorflow-text
  Downloading tensorflow_text-2.6.0-cp37-cp37m-manylinux1_x86_64.whl (4.4 MB)
[K     |████████████████████████████████| 4.4 MB 8.6 MB/s 
[?25hCollecting elasticsearch
  Downloading elasticsearch-7.15.0-py2.py3-none-any.whl (378 kB)
[K     |████████████████████████████████| 378 kB 32.0 MB/s 
[?25hCollecting sentence-transformers
  Downloading sentence-transformers-2.1.0.tar.gz (78 kB)
[K     |████████████████████████████████| 78 kB 4.6 MB/s 
Collecting faiss-cpu
  Downloading faiss_cpu-1.7.1.post2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (8.4 MB)
[K     |████████████████████████████████| 8.4 MB 17.6 MB/s 
[?25hCollecting pytrec-eval
  Downloading pytrec_eval-0.5.tar.gz (15 kB)
Collecting transformers<5.0.0,>=4.6.0
  Downloading transformers-4.11.2-py3-none-any.whl (2.9 MB)
[K     |██████████████████████████

In [2]:
import pathlib, os
import time
import pandas as pd
import random
import requests
import json
import torch
import torch.multiprocessing as mp
from tqdm.notebook import tqdm
from tqdm.autonotebook import trange
from beir import util, LoggingHandler
from beir.retrieval.evaluation import EvaluateRetrieval
from beir.retrieval import models
from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES
from beir.datasets.data_loader import GenericDataLoader

# Run configuration. `hostname` and `index_name` are not used by the cells
# visible in this part of the notebook.
hostname = 'localhost'
dataset = 'scifact'
index_name = dataset

# Download the dataset archive into ./datasets (relative to the current working
# directory) and unpack it.
url = f'https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{dataset}.zip'
out_dir = os.path.join(os.getcwd(), 'datasets')
data_path = util.download_and_unzip(url, out_dir)

# Load one split of the dataset; split names noted by the author: 'test', 'train', 'dev'.
corpus, queries, qrels = GenericDataLoader(data_path).load(split='test')

# Document ids and their documents as two index-aligned lists.
corpus_ids = list(corpus)
corpus_list = [corpus[cid] for cid in corpus_ids]

/content/datasets/scifact.zip:   0%|          | 0.00/2.69M [00:00<?, ?iB/s]

  0%|          | 0/5183 [00:00<?, ?it/s]

In [3]:
def eval_metrics(model_name, ndcg, _map, recall, precision, k_values=(10, 100, 1000)):
  """Collect retrieval scores into a one-row DataFrame labelled with the model name.

  Args:
    model_name: Label used as the index of the single result row.
    ndcg: Mapping with a 'NDCG@<k>' entry for every k in ``k_values``.
    _map: Mapping with a 'MAP@<k>' entry for every k in ``k_values``.
    recall: Mapping with a 'Recall@<k>' entry for every k in ``k_values``.
    precision: Mapping with a 'P@<k>' entry for every k in ``k_values``.
    k_values: Cutoffs to report. Defaults to (10, 100, 1000), the cutoffs
      this function previously hard-coded.

  Returns:
    A pandas DataFrame with one row (index ``model_name``) and one column per
    metric/cutoff pair, ordered MAP, NDCG, P, Recall and, within each metric,
    in the order given by ``k_values``.

  Raises:
    KeyError: If a mapping lacks the entry for one of the requested cutoffs.
  """
  # Column order is fixed by this tuple: all MAP columns first, then NDCG, P, Recall.
  metric_groups = (
      ('MAP', _map),
      ('NDCG', ndcg),
      ('P', precision),
      ('Recall', recall),
  )
  eval_dict = {}
  for prefix, scores in metric_groups:
    for k in k_values:
      key = '{}@{}'.format(prefix, k)
      # Each value is wrapped in a one-element list so it forms a single row.
      eval_dict[key] = [scores[key]]
  # Passing the index to the constructor also keeps an empty ``k_values``
  # from failing on an index-length mismatch.
  eval_df = pd.DataFrame(data=eval_dict, index=[model_name])
  return eval_df

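DENSE - DistilBERT

Compare two MS MARCO DistilBERT checkpoints, msmarco-distilbert-base-tas-b and msmarco-distilbert-base-dot-prod-v3, both scored with the dot product.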

In [4]:
# Wrap each SentenceBERT checkpoint in the DRES searcher. The variable names
# suggest CLS pooling vs. mean pooling; confirm against the model cards.
model_distilbert_dot_cls = DRES(
    models.SentenceBERT('msmarco-distilbert-base-tas-b'),
    batch_size=128,
)
model_distilbert_dot_mean = DRES(
    models.SentenceBERT('msmarco-distilbert-base-dot-prod-v3'),
    batch_size=128,
)

# Both retrievers score with the dot product.
retriever_distilbert_cls = EvaluateRetrieval(
    model_distilbert_dot_cls,
    score_function='dot',
)
retriever_distilbert_mean = EvaluateRetrieval(
    model_distilbert_dot_mean,
    score_function='dot',
)

# Run retrieval for the loaded queries over the loaded corpus, CLS model first.
results_distilbert_cls = retriever_distilbert_cls.retrieve(corpus, queries)
results_distilbert_mean = retriever_distilbert_mean.retrieve(corpus, queries)

# Score each run against the relevance judgements at the retriever's own cutoffs.
ndcg_d_c, _map_d_c, recall_d_c, precision_d_c = retriever_distilbert_cls.evaluate(
    qrels,
    results_distilbert_cls,
    retriever_distilbert_cls.k_values,
)
ndcg_d_m, _map_d_m, recall_d_m, precision_d_m = retriever_distilbert_mean.evaluate(
    qrels,
    results_distilbert_mean,
    retriever_distilbert_mean.k_values,
)

Downloading:   0%|          | 0.00/690 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/3.95k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/548 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/122 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/229 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/265M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/112 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/466k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/547 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/190 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/690 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/2.35k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/554 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/122 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/341 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/265M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/112 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/466k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/376 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/190 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/115 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/2.36M [00:00<?, ?B/s]

Batches:   0%|          | 0/3 [00:00<?, ?it/s]



Batches:   0%|          | 0/41 [00:00<?, ?it/s]

Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Batches:   0%|          | 0/41 [00:00<?, ?it/s]

In [5]:
# Metrics table for the msmarco-distilbert-base-tas-b run. Kept as the cell's
# last expression so the DataFrame is displayed.
eval_metrics(
    model_name='DistilBERT-CLS-dot',
    ndcg=ndcg_d_c,
    _map=_map_d_c,
    recall=recall_d_c,
    precision=precision_d_c,
)

Unnamed: 0,MAP@10,MAP@100,MAP@1000,NDCG@10,NDCG@100,NDCG@1000,P@10,P@100,P@1000,Recall@10,Recall@100,Recall@1000
DistilBERT-CLS-dot,0.59916,0.60459,0.60493,0.64276,0.66983,0.68106,0.08633,0.01013,0.00111,0.7615,0.891,0.98333


In [6]:
# Metrics table for the msmarco-distilbert-base-dot-prod-v3 run. Kept as the
# cell's last expression so the DataFrame is displayed.
eval_metrics(
    model_name='DistilBERT-mean-dot',
    ndcg=ndcg_d_m,
    _map=_map_d_m,
    recall=recall_d_m,
    precision=precision_d_m,
)

Unnamed: 0,MAP@10,MAP@100,MAP@1000,NDCG@10,NDCG@100,NDCG@1000,P@10,P@100,P@1000,Recall@10,Recall@100,Recall@1000
DistilBERT-mean-dot,0.46872,0.47765,0.47821,0.51536,0.55669,0.57277,0.07233,0.00953,0.00109,0.64111,0.831,0.95933


end of fun.