# 0. load libraries

In [3]:
from beir import util
from beir.datasets.data_loader import GenericDataLoader
from tqdm.autonotebook import tqdm
import os, gzip, json
from datasets import load_dataset
from tqdm import tqdm
import numpy as np


# 1. download hard negative passages of msmarco mined by sentence-transformers

In [4]:
# Remote location of the hard-negative triplets file and the local directory
# it is stored in.
triplets_url = "https://sbert.net/datasets/msmarco-hard-negatives.jsonl.gz"
data_path = "/workspace/mnt2/dpr_datasets/msmarco/sbert"
msmarco_triplets_filepath = os.path.join(data_path, "msmarco-hard-negatives.jsonl.gz")
# Skip the download when the file is already present locally.
if not os.path.isfile(msmarco_triplets_filepath):
    # The target directory is not created anywhere else, so make sure it
    # exists before the download writes into it.
    os.makedirs(data_path, exist_ok=True)
    util.download_url(triplets_url, msmarco_triplets_filepath)

In [None]:
# Download and unzip the BEIR copy of MS MARCO, then load its train split.
dataset = "msmarco"
url = f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{dataset}.zip"
out_dir = "/workspace/mnt2/dpr_datasets/msmarco/beir/msmarco"
# `data_path` is rebound to whatever download_and_unzip returns.
data_path = util.download_and_unzip(url, out_dir)
train_loader = GenericDataLoader(data_path)
# The third value returned by load() is not needed and is discarded.
corpus, queries, _ = train_loader.load(split="train")

# 2. select the BM25 hard negative passages whose cross-encoder score is at least 3 lower than the lowest-scoring positive passage

In [None]:
# A negative is kept only if its cross-encoder score is at most
# (lowest positive score - ce_score_margin).
ce_score_margin = 3
# Maximum number of BM25 negatives kept per query.
num_negs_per_system = 10
# qid -> {'query': ..., 'pos': [pid, ...], 'hard_neg': [pid, ...]}
train_queries = {}
# One entry per input line, in file order: how many BM25 candidates were
# rejected for scoring above the threshold (0 for lines that are skipped).
not_selected_samples = []
cnt = 0
with gzip.open(msmarco_triplets_filepath, 'rt', encoding='utf8') as fIn:
    for line in tqdm(fIn, total=502939):
        data = json.loads(line)
        cnt = 0

        # Get the positive passage ids
        pos_pids = [item['pid'] for item in data['pos']]

        # Get the hard negatives. A list keeps them in file order so the
        # output is reproducible; the companion set gives O(1) de-duplication.
        neg_pids = []
        seen_pids = set()

        # Lines without positives have no score to compare against, and lines
        # without BM25 candidates have nothing to select from: skip both.
        if pos_pids and 'bm25' in data['neg']:
            pos_min_ce_score = min(item['ce-score'] for item in data['pos'])
            ce_score_threshold = pos_min_ce_score - ce_score_margin

            for item in data['neg']['bm25']:
                if item['ce-score'] > ce_score_threshold:
                    # Scores too close to the positives: reject and count it.
                    cnt += 1
                    continue

                pid = item['pid']
                if pid in seen_pids:
                    continue
                seen_pids.add(pid)
                neg_pids.append(pid)
                if len(neg_pids) >= num_negs_per_system:
                    break

        # Record the count for THIS line (not the previous one), so the list
        # stays aligned with the input and the last line is not lost.
        not_selected_samples.append(cnt)

        if pos_pids and neg_pids:
            train_queries[data['qid']] = {
                'query': queries[data['qid']],
                'pos': pos_pids,
                'hard_neg': neg_pids,
            }
        

## 3. Preprocess the hard negative passages with the original msmarco data

In [None]:
# NOTE: downloading these datasets took more than 40 minutes.
# All three load_dataset calls share the same cache directory.
hf_cache_dir = '/workspace/mnt2/dpr_datasets/msmarco/original'
# This rebinds `corpus`, replacing the value loaded through GenericDataLoader.
corpus = load_dataset('BeIR/msmarco', 'corpus', cache_dir=hf_cache_dir)
query = load_dataset('BeIR/msmarco', 'queries', cache_dir=hf_cache_dir)
# Original note for the qrels dataset: train/validation/test
qrels = load_dataset('BeIR/msmarco-qrels', cache_dir=hf_cache_dir)

In [None]:
# Index the query rows by their `_id` field.
# This rebinds `queries`, replacing the mapping loaded through GenericDataLoader.
queries = {row['_id']: row for row in tqdm(query['queries'])}

# Index the corpus rows by their `_id` field so passages can be looked up by id.
corpus_ = {row['_id']: row for row in tqdm(corpus['corpus'])}

In [None]:
# Build one training record per selected query, replacing each passage id
# with its full row from `corpus_`.
msmarco = [
    {
        'dataset': 'msmarco',
        'question': {'text': entry['query']},
        'positive_ctxs': [corpus_[pid] for pid in entry['pos']],
        'negative_ctxs': [corpus_[pid] for pid in entry['hard_neg']],
    }
    for entry in tqdm(train_queries.values())
]


In [10]:
# Write the assembled records to disk as indented JSON.
output_path = '/workspace/mnt2/dpr_datasets/msmarco/preprocessed/msmarco_train_filtered.json'
# The output directory is not created anywhere else; create it up front so the
# final write cannot fail on a missing folder after all the work above.
os.makedirs(os.path.dirname(output_path), exist_ok=True)
with open(output_path, 'w', encoding='utf-8') as f:
    json.dump(msmarco, f, indent=4)