# Set up BEIR

In [None]:
# Show the GPU (if any) visible to this kernel; the saved output below lists a Tesla T4.
!nvidia-smi

Fri Mar 29 13:54:57 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla T4                       Off | 00000000:00:04.0 Off |                    0 |
| N/A   40C    P8               9W /  70W |      0MiB / 15360MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [1]:
!pip install beir

Collecting beir
  Downloading beir-2.0.0.tar.gz (53 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.6/53.6 kB[0m [31m158.0 kB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25h  Preparing metadata (setup.py) ... [?25ldone
[?25hCollecting sentence-transformers (from beir)
  Downloading sentence_transformers-2.6.1-py3-none-any.whl.metadata (11 kB)
Collecting pytrec_eval (from beir)
  Downloading pytrec_eval-0.5.tar.gz (15 kB)
  Preparing metadata (setup.py) ... [?25ldone
[?25hCollecting faiss_cpu (from beir)
  Downloading faiss_cpu-1.8.0-cp310-cp310-macosx_11_0_arm64.whl.metadata (3.6 kB)
Collecting elasticsearch==7.9.1 (from beir)
  Downloading elasticsearch-7.9.1-py2.py3-none-any.whl.metadata (8.0 kB)
Collecting datasets (from beir)
  Downloading datasets-2.18.0-py3-none-any.whl.metadata (20 kB)
Collecting urllib3>=1.21.1 (from elasticsearch==7.9.1->beir)
  Using cached urllib3-2.2.1-py3-none-any.whl.metadata (6.4 kB)
Collecting certifi (from elasticsearch=

In [2]:
from beir import util, LoggingHandler

import logging
import pathlib, os

# Route log records (e.g. BEIR's encoding progress and metric lines) to stdout
# through BEIR's LoggingHandler, at INFO level with a timestamp prefix.
logging.basicConfig(
    handlers=[LoggingHandler()],
    level=logging.INFO,
    format='%(asctime)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)

  from tqdm.autonotebook import tqdm


# Set up FinBERT

In [None]:
!pip install transformers numpy torch



In [3]:
from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES
from transformers import AutoModel, AutoTokenizer
import numpy as np
import torch
from tqdm import trange
import os
from typing import List, Dict

class FinBERT:
    """BEIR dense-retrieval model that embeds text with a Hugging Face encoder's first-token vector.

    Implements the ``encode_queries`` / ``encode_corpus`` methods that
    ``DenseRetrievalExactSearch`` (imported as ``DRES``) calls. Queries are encoded
    as single sequences; documents as (title, text) sentence pairs.

    NOTE(review): the [CLS] vector of a checkpoint that was not trained for
    retrieval (e.g. a tone/sentiment classifier) is typically a weak retrieval
    embedding, so low zero-shot scores are to be expected.
    """

    def __init__(self, model_path: str, device, max_length: int = 512, **kwargs):
        """Load the tokenizer and encoder and move the encoder to ``device``.

        Args:
            model_path: Hugging Face model id or local path, e.g. "yiyanghkust/finbert-tone".
            device: torch device or device string ("cuda", "mps", "cpu", ...).
            max_length: token limit per query / (title, text) pair; longer inputs are truncated.
            **kwargs: accepted for interface compatibility and ignored.
        """
        self.device = device
        self.max_length = max_length
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)

        # Queries and documents use the same checkpoint, so load it once and share it
        # instead of holding two identical copies in (accelerator) memory. Load a
        # second copy here if the two towers are ever fine-tuned independently.
        encoder = AutoModel.from_pretrained(model_path)
        encoder.eval()
        encoder.to(self.device)
        self.bert_q = encoder
        self.bert_d = encoder

    def encode_queries(self, queries: List[str], batch_size: int = 16, **kwargs) -> torch.Tensor:
        """Embed queries with the query encoder.

        Args:
            queries: query strings.
            batch_size: number of queries per forward pass.
            **kwargs: extra options passed by BEIR; only ``show_progress_bar`` is used.

        Returns:
            CPU tensor of shape ``(len(queries), hidden_size)``; ``(0, hidden_size)``
            for an empty list (``torch.stack`` would raise on it).
        """
        chunks = []
        show_progress = kwargs.get("show_progress_bar", True)
        for start in trange(0, len(queries), batch_size, disable=not show_progress):
            encoded = self.tokenizer(
                queries[start:start + batch_size],
                truncation=True,
                padding=True,
                return_tensors='pt',
                max_length=self.max_length,
            )
            chunks.append(self._cls_embed(self.bert_q, encoded))
        return self._concat(self.bert_q, chunks)

    def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int = 16, **kwargs) -> torch.Tensor:
        """Embed documents with the document encoder.

        Args:
            corpus: documents as dicts with ``'title'`` and ``'text'`` entries; a
                missing or ``None`` value is treated as an empty string.
            batch_size: number of documents per forward pass.
            **kwargs: extra options passed by BEIR; only ``show_progress_bar`` is used.

        Returns:
            CPU tensor of shape ``(len(corpus), hidden_size)``; ``(0, hidden_size)``
            for an empty corpus.
        """
        chunks = []
        show_progress = kwargs.get("show_progress_bar", True)
        for start in trange(0, len(corpus), batch_size, disable=not show_progress):
            batch = corpus[start:start + batch_size]
            # `or ''` guards against None fields, which the tokenizer cannot handle.
            titles = [row.get('title') or '' for row in batch]
            texts = [row.get('text') or '' for row in batch]
            # Encode as a sentence pair; 'longest_first' trims the longer segment
            # first when title + text exceed max_length.
            encoded = self.tokenizer(
                titles,
                texts,
                truncation='longest_first',
                padding=True,
                return_tensors='pt',
                max_length=self.max_length,
            )
            chunks.append(self._cls_embed(self.bert_d, encoded))
        return self._concat(self.bert_d, chunks)

    @torch.no_grad()
    def _cls_embed(self, model, encoded) -> torch.Tensor:
        """Run ``model`` on one tokenized batch and return its first-token vectors on the CPU.

        For BERT tokenizers the first token is [CLS]; no gradients are tracked.
        """
        encoded = encoded.to(self.device)
        model_out = model(**encoded)
        return model_out.last_hidden_state[:, 0, :].cpu()

    @staticmethod
    def _concat(model, chunks) -> torch.Tensor:
        """Join per-batch embeddings row-wise; no batches yields a ``(0, hidden_size)`` tensor."""
        if not chunks:
            return torch.empty((0, model.config.hidden_size))
        return torch.cat(chunks, dim=0)

2024-03-30 12:09:55 - PyTorch version 2.2.2 available.
2024-03-30 12:09:59 - Loading faiss.
2024-03-30 12:09:59 - Successfully loaded faiss.


# Set up datasets

In [35]:
import pathlib, os
from beir import util

def download_dataset(dataset_name: str, out_dir=None):
    """Download and unzip a BEIR dataset and return the local dataset folder.

    Args:
        dataset_name: BEIR dataset id, e.g. "scifact", "fiqa" or "trec-covid".
        out_dir: directory to download and unzip into. Defaults to ``datasets``
            under the current working directory, the previously hard-coded location.

    Returns:
        Path of the unzipped dataset folder, as returned by
        ``beir.util.download_and_unzip``.
    """
    url = f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{dataset_name}.zip"
    if out_dir is None:
        out_dir = os.path.join(os.getcwd(), "datasets")
    data_path = util.download_and_unzip(url, out_dir)
    print(f"Dataset downloaded here: {data_path}")
    return data_path

In [None]:
#!ls datasets/scifact/

In [40]:
from beir.datasets.data_loader import GenericDataLoader


# Dataset under evaluation; "scifact" and "fiqa" were also tried.
dataset = "trec-covid"
data_path = download_dataset(dataset)
# Load the test split (or split="train" / split="dev").
corpus, queries, qrels = GenericDataLoader(data_path).load(split="test")

2024-04-01 16:44:57 - Downloading trec-covid.zip ...


/Users/alexmano/Documents/projects/information-retrieval/ir-cross-evaluations/datasets/trec-covid.zip: 100%|██████████| 70.5M/70.5M [00:01<00:00, 50.8MiB/s]


2024-04-01 16:44:59 - Unzipping trec-covid.zip ...
Dataset downloaded here: /Users/alexmano/Documents/projects/information-retrieval/ir-cross-evaluations/datasets/trec-covid
2024-04-01 16:45:00 - Loading Corpus...


100%|██████████| 171332/171332 [00:00<00:00, 261254.82it/s]

2024-04-01 16:45:01 - Loaded 171332 TEST Documents.
2024-04-01 16:45:01 - Doc Example: {'text': 'OBJECTIVE: This retrospective chart review describes the epidemiology and clinical features of 40 patients with culture-proven Mycoplasma pneumoniae infections at King Abdulaziz University Hospital, Jeddah, Saudi Arabia. METHODS: Patients with positive M. pneumoniae cultures from respiratory specimens from January 1997 through December 1998 were identified through the Microbiology records. Charts of patients were reviewed. RESULTS: 40 patients were identified, 33 (82.5%) of whom required admission. Most infections (92.5%) were community-acquired. The infection affected all age groups but was most common in infants (32.5%) and pre-school children (22.5%). It occurred year-round but was most common in the fall (35%) and spring (30%). More than three-quarters of patients (77.5%) had comorbidities. Twenty-four isolates (60%) were associated with pneumonia, 14 (35%) with upper respiratory tract 




# Evaluate

In [9]:
from beir.retrieval.evaluation import EvaluateRetrieval

# Use the best available accelerator: CUDA (e.g. the Tesla T4 listed by nvidia-smi),
# then Apple-silicon MPS, else CPU. Hard-coding "mps" fails on non-Apple machines.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# DRES scores every query against every encoded corpus document.
finbert = DRES(FinBERT(model_path="yiyanghkust/finbert-tone", device=device), batch_size=16)
# "dot": raw inner product; FinBERT does not normalise its embeddings.
retriever = EvaluateRetrieval(finbert, score_function="dot")

# NOTE(review): the outputs saved for this cell and the metric cells below appear to come
# from an earlier run on a different dataset. They show 19 query batches and 324 corpus
# batches of 16 (~300 queries, ~5.2k docs; consistent with scifact) and Accuracy@5 = 1/300.
# The data cell now loads trec-covid (50 queries, 171,332 docs). Re-run before trusting them.
results = retriever.retrieve(corpus, queries)

2024-03-30 12:30:37 - Encoding Queries...


100%|██████████| 19/19 [00:03<00:00,  6.09it/s]


2024-03-30 12:30:40 - Sorting Corpus by document length (Longest first)...
2024-03-30 12:30:40 - Scoring Function: Dot Product (dot)
2024-03-30 12:30:40 - Encoding Batch 1/1...


100%|██████████| 324/324 [02:11<00:00,  2.46it/s]


In [43]:
# Quick look at the loaded query ids: they are strings (see output below).
# NOTE(review): keep them as strings; converting only `queries` to int keys would
# likely stop them matching the ids used in `qrels`.
queries.keys()

dict_keys(['1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '20', '21', '22', '23', '24', '25', '26', '27', '28', '29', '30', '31', '32', '33', '34', '35', '36', '37', '38', '39', '40', '41', '42', '43', '44', '45', '46', '47', '48', '49', '50'])

In [10]:
# NDCG / MAP / Recall / Precision at every cutoff in retriever.k_values.
# ignore_identical_ids=True is the default: a hit whose doc id equals the query id is ignored.
ndcg, _map, recall, precision = EvaluateRetrieval.evaluate(
    qrels,
    results,
    retriever.k_values,
    ignore_identical_ids=True,
)
ndcg, _map, recall, precision

2024-03-30 12:32:52 - For evaluation, we ignore identical query and document ids (default), please explicitly set ``ignore_identical_ids=False`` to ignore this.
2024-03-30 12:32:52 - 

2024-03-30 12:32:52 - NDCG@1: 0.0000
2024-03-30 12:32:52 - NDCG@3: 0.0000
2024-03-30 12:32:52 - NDCG@5: 0.0006
2024-03-30 12:32:52 - NDCG@10: 0.0016
2024-03-30 12:32:52 - NDCG@100: 0.0084
2024-03-30 12:32:52 - NDCG@1000: 0.0459
2024-03-30 12:32:52 - 

2024-03-30 12:32:52 - MAP@1: 0.0000
2024-03-30 12:32:52 - MAP@3: 0.0000
2024-03-30 12:32:52 - MAP@5: 0.0002
2024-03-30 12:32:52 - MAP@10: 0.0006
2024-03-30 12:32:52 - MAP@100: 0.0015
2024-03-30 12:32:52 - MAP@1000: 0.0024
2024-03-30 12:32:52 - 

2024-03-30 12:32:52 - Recall@1: 0.0000
2024-03-30 12:32:52 - Recall@3: 0.0000
2024-03-30 12:32:52 - Recall@5: 0.0011
2024-03-30 12:32:52 - Recall@10: 0.0044
2024-03-30 12:32:52 - Recall@100: 0.0400
2024-03-30 12:32:52 - Recall@1000: 0.3564
2024-03-30 12:32:52 - 

2024-03-30 12:32:52 - P@1: 0.0000
2024-03-30 12:32:52

({'NDCG@1': 0.0,
  'NDCG@3': 0.0,
  'NDCG@5': 0.00061,
  'NDCG@10': 0.00157,
  'NDCG@100': 0.00843,
  'NDCG@1000': 0.04585},
 {'MAP@1': 0.0,
  'MAP@3': 0.0,
  'MAP@5': 0.00022,
  'MAP@10': 0.00056,
  'MAP@100': 0.00147,
  'MAP@1000': 0.00241},
 {'Recall@1': 0.0,
  'Recall@3': 0.0,
  'Recall@5': 0.00111,
  'Recall@10': 0.00444,
  'Recall@100': 0.04,
  'Recall@1000': 0.35639},
 {'P@1': 0.0,
  'P@3': 0.0,
  'P@5': 0.00067,
  'P@10': 0.00067,
  'P@100': 0.0005,
  'P@1000': 0.0004})

In [11]:
# Additional BEIR metrics, one evaluate_custom call per metric name (in this order).
mrr, recall_cap, hole, top_k_accuracy = (
    retriever.evaluate_custom(qrels, results, retriever.k_values, metric=metric_name)
    for metric_name in ("mrr", "recall_cap", "hole", "top_k_accuracy")
)
mrr, recall_cap, hole, top_k_accuracy

2024-03-30 12:32:52 - 

2024-03-30 12:32:52 - MRR@1: 0.0000
2024-03-30 12:32:52 - MRR@3: 0.0000
2024-03-30 12:32:52 - MRR@5: 0.0007
2024-03-30 12:32:52 - MRR@10: 0.0010
2024-03-30 12:32:52 - MRR@100: 0.0020
2024-03-30 12:32:52 - MRR@1000: 0.0030
2024-03-30 12:32:52 - 

2024-03-30 12:32:52 - R_cap@1: 0.0000
2024-03-30 12:32:52 - R_cap@3: 0.0000
2024-03-30 12:32:52 - R_cap@5: 0.0011
2024-03-30 12:32:52 - R_cap@10: 0.0044
2024-03-30 12:32:52 - R_cap@100: 0.0400
2024-03-30 12:32:52 - R_cap@1000: 0.3564
2024-03-30 12:32:52 - 

2024-03-30 12:32:52 - Hole@1: 0.8233
2024-03-30 12:32:52 - Hole@3: 0.8789
2024-03-30 12:32:52 - Hole@5: 0.8953
2024-03-30 12:32:52 - Hole@10: 0.9220
2024-03-30 12:32:52 - Hole@100: 0.9581
2024-03-30 12:32:52 - Hole@1000: 0.9510
2024-03-30 12:32:52 - 

2024-03-30 12:32:52 - Accuracy@1: 0.0000
2024-03-30 12:32:52 - Accuracy@3: 0.0000
2024-03-30 12:32:52 - Accuracy@5: 0.0033
2024-03-30 12:32:52 - Accuracy@10: 0.0067
2024-03-30 12:32:52 - Accuracy@100: 0.0467
2024-03-30 1

({'MRR@1': 0.0,
  'MRR@3': 0.0,
  'MRR@5': 0.00067,
  'MRR@10': 0.001,
  'MRR@100': 0.00203,
  'MRR@1000': 0.003},
 {'R_cap@1': 0.0,
  'R_cap@3': 0.0,
  'R_cap@5': 0.00111,
  'R_cap@10': 0.00444,
  'R_cap@100': 0.04,
  'R_cap@1000': 0.35639},
 {'Hole@1': 0.82333,
  'Hole@3': 0.87889,
  'Hole@5': 0.89533,
  'Hole@10': 0.922,
  'Hole@100': 0.9581,
  'Hole@1000': 0.95102},
 {'Accuracy@1': 0.0,
  'Accuracy@3': 0.0,
  'Accuracy@5': 0.00333,
  'Accuracy@10': 0.00667,
  'Accuracy@100': 0.04667,
  'Accuracy@1000': 0.37667})