In [None]:
!pip install pyserini==0.25.0 pytrec_eval datasets tqdm
!pip install faiss-cpu

We highly recommend running this code on Google Colab: https://colab.research.google.com/drive/1Mdo5yRB1Sz4nJ5gtNbe1y9DJz8TPDfBv?usp=sharing
The code is partially adapted from https://huggingface.co/datasets/intfloat/query2doc_msmarco/blob/main/repro_bm25.py

In [None]:
import json
import random
import urllib.request
from typing import Dict, Iterable, List, Optional, Tuple

import faiss
import numpy as np
import pytrec_eval
import tqdm
from datasets import load_dataset
from pyserini.search import SimpleSearcher

In [None]:
def calculate_means(my_results):
    """Print and return the per-metric mean and standard deviation over several runs.

    Args:
        my_results: list of metric dicts (e.g. outputs of ``trec_eval``). The
            metric names are taken from the first dict; every dict must contain them.

    Returns:
        ``(means, std_devs)``: two dicts keyed by metric name. Both are empty
        when ``my_results`` is empty.
    """
    if not my_results:
        # Nothing to aggregate; avoid indexing my_results[0] on an empty list.
        print("No results to aggregate.")
        return {}, {}

    metric_names = list(my_results[0])
    means = {name: np.mean([run[name] for run in my_results]) for name in metric_names}
    # np.std default (ddof=0): population standard deviation across runs.
    std_devs = {name: np.std([run[name] for run in my_results]) for name in metric_names}

    print("Means:")
    for metric, mean_value in means.items():
        print(f"{metric}: {mean_value}")

    print("\nStandard Deviations:")
    for metric, std_value in std_devs.items():
        print(f"{metric}: {std_value}")

    return means, std_devs

def trec_eval(qrels: Dict[str, Dict[str, int]],
              results: Dict[str, Dict[str, float]],
              k_values: Tuple[int, ...] = (10, 50, 100, 200, 1000)) -> Dict[str, float]:
    """Average NDCG@k, MAP@k and Recall@k over the queries scored by pytrec_eval.

    Args:
        qrels: ``qid -> {docid: relevance}`` judgements.
        results: ``qid -> {docid: retrieval score}`` run to evaluate.
        k_values: cutoffs at which every metric is reported.

    Returns:
        Flat dict with keys ``NDCG@k``, then ``MAP@k``, then ``Recall@k`` (one
        per cutoff), each averaged over the queries returned by the evaluator
        and rounded to 5 decimals. All values are 0.0 when no query is scored.
    """
    cutoffs = ",".join(str(k) for k in k_values)
    evaluator = pytrec_eval.RelevanceEvaluator(
        qrels, {f"map_cut.{cutoffs}", f"ndcg_cut.{cutoffs}", f"recall.{cutoffs}"})
    scores = evaluator.evaluate(results)
    num_queries = len(scores)

    # (reported metric name, per-query key prefix produced by pytrec_eval)
    measures = (("NDCG", "ndcg_cut_"), ("MAP", "map_cut_"), ("Recall", "recall_"))

    all_metrics: Dict[str, float] = {}
    for name, key_prefix in measures:
        for k in k_values:
            total = sum(per_query[f"{key_prefix}{k}"] for per_query in scores.values())
            # Guard against division by zero when the run and qrels share no query.
            all_metrics[f"{name}@{k}"] = round(total / num_queries, 5) if num_queries else 0.0

    return all_metrics


def load_qrels_from_file(file_path: str,
                         query_ids: Optional[Iterable[str]] = None) -> Dict[str, Dict[str, int]]:
    """Load relevance judgements from a whitespace-separated 4-column qrels file.

    Each non-blank line must contain ``qid``, an ignored second column, ``pid``
    and an integer relevance label.

    Args:
        file_path: path to the qrels file (e.g. ``qrels_train.tsv``).
        query_ids: optional collection of query ids; when given, only
            judgements for these queries are kept.

    Returns:
        Mapping ``qid -> {pid: relevance}``.
    """
    keep = None if query_ids is None else set(query_ids)
    qrels: Dict[str, Dict[str, int]] = {}
    with open(file_path, 'r') as file:
        # Stream the file line by line instead of materialising readlines().
        for line in file:
            if not line.strip():
                continue  # tolerate blank / trailing lines
            qid, _, pid, score = line.split()
            if keep is not None and qid not in keep:
                continue
            qrels.setdefault(qid, {})[pid] = int(score)
    print('Load {} queries {} qrels from {}'.format(len(qrels), sum(len(v) for v in qrels.values()), file_path))
    return qrels

In [None]:
# Load the query2doc MS MARCO train split and partition its query ids into
# long-tail ids (listed in long_tail_queries.txt) and common ids (all others).
query2doc_dataset = load_dataset('intfloat/query2doc_msmarco')['train']
with open('long_tail_queries.txt', 'r') as file:
    ids = [line.strip() for line in file if line.strip()]
# Set for O(1) membership tests; `ids` itself stays a list.
long_tail_id_set = set(ids)
# Read the id column once instead of decoding every row twice.
all_query_ids = query2doc_dataset['query_id']
random_ids_long_tail = [qid for qid in all_query_ids if qid in long_tail_id_set]
random_ids_common = [qid for qid in all_query_ids if qid not in long_tail_id_set]

In [None]:
my_results_common = []


def main(split: str = 'train', num_runs: int = 5, sample_size: int = 50,
         seed: Optional[int] = None) -> List[Dict[str, float]]:
    """BM25 + query2doc expansion on random samples of long-tail queries.

    Each trial works as follows:
    - samples ``sample_size`` ids from ``random_ids_long_tail``;
    - builds the query ``query x5 + pseudo_doc``;
    - retrieves the top 1000 passages from the prebuilt ``msmarco-passage`` index;
    - scores the run with ``trec_eval``.

    Per-trial metrics are summarised by ``calculate_means``.

    Args:
        split: unused except for the signature (qrels are always read from
            ``qrels_train.tsv``).
        num_runs: number of independent sampling trials.
        sample_size: number of query ids sampled per trial.
        seed: optional sampling seed. ``None`` uses the global ``random`` state.

    Returns:
        The per-trial metric dicts. They are also stored in the global
        ``my_results_common``, which holds only this call's trials.
    """
    global my_results_common
    rng = random if seed is None else random.Random(seed)
    searcher: SimpleSearcher = SimpleSearcher.from_prebuilt_index('msmarco-passage')
    # The qrels do not change between trials, so load them once.
    qrels = load_qrels_from_file('qrels_train.tsv')

    run_metrics: List[Dict[str, float]] = []
    for _ in range(num_runs):
        sampled_ids = set(rng.sample(random_ids_long_tail, k=sample_size))
        query2doc_my_dataset = query2doc_dataset.filter(lambda example: example['query_id'] in sampled_ids)
        query_ids = query2doc_my_dataset['query_id']
        # Expanded query: original query repeated 5 times, followed by the pseudo document.
        queries = ['{} {}'.format(' '.join([example['query']] * 5), example['pseudo_doc'])
                   for example in query2doc_my_dataset]
        print('Load {} queries'.format(len(queries)))

        results: Dict[str, Dict[str, float]] = {}
        batch_size = 64
        for start in tqdm.tqdm(range(0, len(queries), batch_size), mininterval=2):
            batch_query_ids = query_ids[start:start + batch_size]
            batch_queries = queries[start:start + batch_size]
            qid_to_hits: dict = searcher.batch_search(batch_queries, qids=batch_query_ids, k=1000, threads=8)
            for qid, hits in qid_to_hits.items():
                results[qid] = {hit.docid: hit.score for hit in hits}

        run_metrics.append(trec_eval(qrels=qrels, results=results))

    # Rebind (not append to) the global so repeated calls do not mix trials.
    my_results_common = run_metrics
    calculate_means(run_metrics)
    return run_metrics


if __name__ == '__main__':
    main(split='train')

In [None]:
from pyserini.search.faiss import FaissSearcher, TctColBertQueryEncoder


def main(split: str = 'train', sample_size: int = 50, seed: Optional[int] = None) -> Dict[str, float]:
    """Dense retrieval (TCT-ColBERT, HNSW index) + query2doc on sampled long-tail queries.

    Samples ``sample_size`` ids from ``random_ids_long_tail``, builds the query
    ``query x5 + pseudo_doc`` for each, retrieves the top 1000 passages and
    prints the ``trec_eval`` metrics.

    Args:
        split: label printed in the results header.
        sample_size: number of query ids to sample.
        seed: optional sampling seed. ``None`` uses the global ``random`` state.

    Returns:
        The metric dict produced by ``trec_eval``.
    """
    rng = random if seed is None else random.Random(seed)
    encoder = TctColBertQueryEncoder('castorini/tct_colbert-msmarco')
    searcher = FaissSearcher.from_prebuilt_index(
        'msmarco-passage-tct_colbert-hnsw',
        encoder
    )

    sampled_ids = set(rng.sample(random_ids_long_tail, k=sample_size))
    query2doc_my_dataset = query2doc_dataset.filter(lambda example: example['query_id'] in sampled_ids)
    query_ids = query2doc_my_dataset['query_id']
    # NOTE(review): the 5x query repetition mirrors the BM25 recipe; confirm it is
    # also intended for the dense encoder.
    queries = ['{} {}'.format(' '.join([example['query']] * 5), example['pseudo_doc'])
               for example in query2doc_my_dataset]
    print('Load {} queries'.format(len(queries)))

    results: Dict[str, Dict[str, float]] = {}
    batch_size = 64
    for start in tqdm.tqdm(range(0, len(queries), batch_size), mininterval=2):
        batch_query_ids = query_ids[start:start + batch_size]
        batch_queries = queries[start:start + batch_size]
        # FaissSearcher.batch_search calls its id parameter `q_ids`, not `qids`,
        # so pass the ids positionally.
        qid_to_hits: dict = searcher.batch_search(batch_queries, batch_query_ids, k=1000, threads=8)
        for qid, hits in qid_to_hits.items():
            results[qid] = {hit.docid: hit.score for hit in hits}

    # Call with the path only; the original extra `ids` argument did not match
    # the loader's single-parameter signature.
    qrels = load_qrels_from_file('qrels_train.tsv')
    all_metrics = trec_eval(qrels=qrels, results=results)

    print('Evaluation results for {} split:'.format(split))
    print(json.dumps(all_metrics, ensure_ascii=False, indent=4))
    return all_metrics


if __name__ == '__main__':
    main(split='train')

In [None]:
my_results_common = []


def main(split: str = 'train', num_runs: int = 5, sample_size: int = 50,
         seed: Optional[int] = None) -> List[Dict[str, float]]:
    """Baseline BM25 without query2doc expansion on random samples of common queries.

    Each trial works as follows:
    - samples ``sample_size`` ids from ``random_ids_common``;
    - uses the original query repeated 5 times (no pseudo document);
    - retrieves the top 1000 passages from the prebuilt ``msmarco-passage`` index;
    - scores the run with ``trec_eval``.

    Per-trial metrics are summarised by ``calculate_means``.

    Args:
        split: unused except for the signature (qrels are always read from
            ``qrels_train.tsv``).
        num_runs: number of independent sampling trials.
        sample_size: number of query ids sampled per trial.
        seed: optional sampling seed. ``None`` uses the global ``random`` state.

    Returns:
        The per-trial metric dicts. They are also stored in the global
        ``my_results_common``, which holds only this call's trials.
    """
    global my_results_common
    rng = random if seed is None else random.Random(seed)
    searcher: SimpleSearcher = SimpleSearcher.from_prebuilt_index('msmarco-passage')
    # The qrels do not change between trials, so load them once.
    qrels = load_qrels_from_file('qrels_train.tsv')

    run_metrics: List[Dict[str, float]] = []
    for _ in range(num_runs):
        sampled_ids = set(rng.sample(random_ids_common, k=sample_size))
        query2doc_my_dataset = query2doc_dataset.filter(lambda example: example['query_id'] in sampled_ids)
        query_ids = query2doc_my_dataset['query_id']
        # Baseline query: the original query repeated 5 times, without pseudo_doc.
        queries = [' '.join([example['query']] * 5) for example in query2doc_my_dataset]
        print('Load {} queries'.format(len(queries)))

        results: Dict[str, Dict[str, float]] = {}
        batch_size = 64
        for start in tqdm.tqdm(range(0, len(queries), batch_size), mininterval=2):
            batch_query_ids = query_ids[start:start + batch_size]
            batch_queries = queries[start:start + batch_size]
            qid_to_hits: dict = searcher.batch_search(batch_queries, qids=batch_query_ids, k=1000, threads=8)
            for qid, hits in qid_to_hits.items():
                results[qid] = {hit.docid: hit.score for hit in hits}

        run_metrics.append(trec_eval(qrels=qrels, results=results))

    # Rebind (not append to) the global so repeated calls do not mix trials.
    my_results_common = run_metrics
    calculate_means(run_metrics)
    return run_metrics


if __name__ == '__main__':
    main(split='train')