# Index pruning

This notebook builds the set of all passage indices to keep in the pruned index. It uses the probabilities produced by a __trained irrelevance classifier__ and the __training data file__, which contains the golden passages.

In [None]:
import h5py
import os
import numpy as np
from tqdm import tqdm

# In this example, the pruning probabilities are kept in <rootdir>/.pruning directory
# The chdir target is relative, so it is applied only once per kernel session:
# re-running this cell must not climb another three directories up.
if "_REPO_ROOT" not in globals():
    os.chdir("../../..")
    _REPO_ROOT = os.getcwd()
print(os.getcwd())

#### Change these paths if needed ########
total_passages=1_700_000 # total number of passages to keep

# set path to passage probabilities
file=".pruning/psgs_w100_irrelevant_passage_probs_electra_nqopen.h5"
#file=".pruning/psgs_w100_irrelevant_passage_probs_electra_trivia.h5"

# set training file path
training_data_source = ".data/nqopen/nq-open_train_short_maxlen_5_ms_with_dpr_annotation.jsonl"
#training_data_source = ".data/triviaopen/triviaqa-open_train_with_dpr_annotation.jsonl"

# here the set of all relevant passages will be saved
# splitext drops the extension whatever its length (the original sliced off
# exactly three characters, which is only right for ".h5").
probs_file_stem = os.path.splitext(os.path.basename(file))[0]
output_index_file = ".pruning/relevant_" +  \
                    probs_file_stem+"_with_"+ \
                    os.path.basename(training_data_source)+ \
                    f"_p{total_passages}"+".pkl"

In [2]:
# Load passage probs/scores
# The context manager closes the HDF5 file as soon as the dataset has been
# read; indexing with [()] reads the whole 'data' dataset into memory, so
# `data` stays usable after the file is closed.
with h5py.File(file, 'r') as probs_h5:
    data = probs_h5['data'][()]
# Column 0 is used as the score and column 1 as the probability.
# NOTE(review): the exact meaning of the two columns is set by whatever wrote
# the file — confirm against the classifier export script.
scores, probs = data[:,0], data[:,1]

### Let's choose the threshold and compute how many documents will be kept with this classifier threshold

In [3]:
# Order a private copy of the scores ascending, leaving `scores` untouched.
sorted_scores = scores.copy()
sorted_scores.sort()
# The threshold is the score at rank `total_passages` in that ordering.
t = sorted_scores[total_passages]
# `!s` forces str() on the value, so the text matches plain string
# concatenation rather than the scalar's own format() output.
print(f"Threshold is: {t!s}")

Threshold is: -2.7950857


In [4]:
# Boolean mask of the passages whose score lies strictly below the threshold.
below_threshold = scores < t
print(f"Keeping {below_threshold.sum()} indices")
# nonzero() yields a tuple of index arrays, one per dimension of the mask.
indices_to_keep = below_threshold.nonzero()

Keeping 1700000 indices


In [5]:
# Take the first index array from the nonzero() tuple and convert it to a set
# of plain Python integers.
kept_positions = indices_to_keep[0]
relevant_indices_to_keep = set(kept_positions.tolist())
# Preview up to 20 members in the set's own iteration order.
print([idx for _, idx in zip(range(20), relevant_indices_to_keep)])

[0, 1, 2, 3, 4, 5, 4194305, 7, 8, 9, 10, 11, 12, 13, 14, 15, 4194311, 17, 18, 4194318]


### Keep all training gold passages inside the index

In [6]:
# Project-local helper, imported here rather than in the first cell.
# NOTE(review): presumably the import only resolves after the working directory
# has been changed to the repository root in the first cell — confirm.
from scalingqa.index_pruning.dataset.NQ.build_dataset import get_golden_passages

# Collect the golden passages referenced by the training data file.
# NOTE(review): the return type is not visible from this notebook; the later
# cells only need it to be an iterable of hashable passage indices.
positive_documents_gt_indices = get_golden_passages(training_data_source)
# Display the first 20 entries as a quick sanity check.
list(positive_documents_gt_indices)[:20]

[7077893,
 13631497,
 1572875,
 13238285,
 1310733,
 14024719,
 1703951,
 13631502,
 4849677,
 131092,
 6291477,
 19660825,
 6815773,
 14942239,
 262180,
 262187,
 9699372,
 1572909,
 17432619,
 20185135]

In [7]:
# Final index = passages picked by the classifier threshold plus every golden
# passage from the training data.
total_indices_to_keep = relevant_indices_to_keep | set(positive_documents_gt_indices)

In [8]:
# Gather the four sizes first, then report them.
kept_total = len(total_indices_to_keep)
kept_by_classifier = len(relevant_indices_to_keep)
gold_total = len(set(positive_documents_gt_indices))
# Growth of the union over the classifier-selected set, i.e. the golden
# passages the classifier threshold alone would have dropped.
gold_added = kept_total - kept_by_classifier

print(f"Total size of kept index: {kept_total}")
print(f"Size chosen via binary classifier: {kept_by_classifier}")
print(f"Total size of training data index: {gold_total}")
print(f"Size of training data index missing from index kept via binary classifier: {gold_added}")

Total size of kept index: 1702133
Size chosen via binary classifier: 1700000
Total size of training data index: 40670
Size of training data index missing from index kept via binary classifier: 2133


In [None]:
import pickle

# Announce the destination, then serialise the final set of kept passage
# indices to it with pickle.
print(f"Saving total indices after pruning as set into {output_index_file}")
with open(output_index_file, mode="wb") as index_sink:
    pickle.dump(total_indices_to_keep, index_sink)