In [1]:
import os
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-11-openjdk-amd64" # Make sure this is the sama JAVA_HOME as the installed version on the previous notebook!
data_home = "/ssd2/arthur/MsMarcoTREC/"
def path(x):
    return os.path.join(data_home, x)

try:
    import pyserini
except:
    !pip install pyserini==0.8.1.0 # install pyserini
try:
    import tqdm
except:
    !pip install tqdm # Good for progress bars!

In [2]:
import jnius_config
jnius_config.add_options('-Xmx32G')
from pyserini.search import pysearch
import subprocess
from tqdm.auto import tqdm
import random
import pickle
import sys
import unicodedata
import string
import re
import os
from collections import defaultdict
import math

In [3]:
index_path = path("lucene-index.msmarco-doc.pos+docvectors+rawdocs")
searcher = pysearch.SimpleSearcher(index_path)
relevant_docs = defaultdict(lambda:[])
for file in [path("qrels/msmarco-doctrain-qrels.tsv"), path("qrels/msmarco-docdev-qrels.tsv")]:
    for line in open(file):
        query_id, _, doc_id, rel = line.split()
        assert rel == "1"
        relevant_docs[query_id].append(doc_id)                            

In [4]:
pattern = re.compile('([^\s\w]|_)+')
printable = set(string.printable)

anserini_top_10 = defaultdict(lambda:[])
searcher.set_bm25_similarity(0.9, 0.4)
pairs_per_split = defaultdict(lambda: [])
threads = 50 # Number of Threads to use when retrieving
k = 10       # Number of documents to retrieve 
neg_samples = 2 # Number of negatives samples to use
batch_size = 10000 # Batch size for each retrieval step on Anserini

query_texts = dict()
for split in ["train", "dev"]:
    file_path = path(f"queries/msmarco-doc{split}-queries.tsv")
    run_search=True
    if os.path.isfile(file_path):
        print(f"Already found file {file_path}. Cowardly refusing to run this again. Will only load querytexts.")
        pairs_per_split[split] = pickle.load(open(path(f"{split}_triples.pkl"), 'rb'))
        run_search = False
    number_of_queries = int(subprocess.run(f"wc -l {file_path}".split(), capture_output=True).stdout.split()[0])
    number_of_batches = math.ceil(number_of_queries/batch_size)
    pbar = tqdm(total=number_of_batches, desc="Retrieval batches")
    queries = []
    query_ids = []
    for idx, line in enumerate(open(file_path, encoding="utf-8")):
        query_id, query = line.strip().split("\t")
        query_ids.append(query_id)
        query = unicodedata.normalize("NFKD", query) # Force queries into UTF-8
        query = pattern.sub(' ',query) # Remove non-ascii characters. It clears up most of the issues we may find on the query datasets
        queries.append(query)
        if len(queries) == batch_size or idx == number_of_queries-1:
            results = searcher.batch_search(queries, query_ids, k=k, threads=threads)
            pbar.update()
            for query, query_id in zip(queries, query_ids):
                retrieved_docs_ids = [hit.docid for hit in results[query_id]]
                relevant_docs_for_query = relevant_docs[query_id]
                retrieved_non_relevant_documents = set(retrieved_docs_ids).difference(set(relevant_docs_for_query))
                  
                if len(retrieved_non_relevant_documents) < 2:
                    print(f"query {query} has less than 2 retrieved docs.")
                    continue
                random_negative_samples = random.sample(retrieved_non_relevant_documents, neg_samples)
                pairs_per_split[split] += [(query_id, doc_id, 1) for doc_id in relevant_docs_for_query]
                pairs_per_split[split] += [(query_id, doc_id, 0) for doc_id in random_negative_samples]
            queries = []
            query_ids = []
    pickle.dump(pairs_per_split[split], open(path(f"{split}_triples.pkl"), 'wb'))
    pbar.close()


Already found file /ssd2/arthur/MsMarcoTREC/queries/msmarco-doctrain-queries.tsv. Cowardly refusing to run this again. Will only load querytexts.


HBox(children=(FloatProgress(value=0.0, description='Retrieval batches', max=37.0, style=ProgressStyle(descrip…

query standingdefinition has less than 2 retrieved docs.
query a pantsman has less than 2 retrieved docs.

Already found file /ssd2/arthur/MsMarcoTREC/queries/msmarco-docdev-queries.tsv. Cowardly refusing to run this again. Will only load querytexts.


HBox(children=(FloatProgress(value=0.0, description='Retrieval batches', max=1.0, style=ProgressStyle(descript…


