In [2]:
import os
import time

import faiss
import numpy as np
import pyarrow as pa
from loguru import logger
from tqdm import tqdm

In [5]:
# ---- Experiment settings (edit these) ----

# Root directory of the DPR experiment tree (embeddings are read from below it)
DPR_root = '/mnt/disks/experiments/DPR/'

# 'MR' selects dpr-nq-d768_384_192_96_48-wiki; 'RR' selects dpr-nq-d768-wiki
config = 'MR'

# FAISS index built by the full-training cell: 'IP', 'L2', 'IVF', 'OPQ' or 'IVFOPQ'
index_type = 'IVFOPQ'

# False -> load every embedding into memory at once (high peak RAM)
# True  -> build an exact IP index record batch by record batch instead
train_batches = False

In [8]:
# Resolve the config key to its run name; an unknown key fails loudly instead of
# silently falling back to the RR config.
CONFIG_NAMES = {
    'MR': 'dpr-nq-d768_384_192_96_48-wiki',
    'RR': 'dpr-nq-d768-wiki',  # RR-768
}
if config not in CONFIG_NAMES:
    raise ValueError(f"Unknown config {config!r}; expected one of {sorted(CONFIG_NAMES)}")
config_name = CONFIG_NAMES[config]

# os.path.join works whether or not DPR_root ends with a separator.
embeddings_file = os.path.join(DPR_root, 'results', 'embed', f'{config_name}.arrow')
# Open the Arrow IPC file through a memory map so record batches can be read
# zero-copy from disk instead of being copied into RAM up front.
emb_data = pa.ipc.open_file(pa.memory_map(embeddings_file, "rb")).read_all()

## Batched Index Training (RAM-constrained)
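Build an exact (flat, inner-product) search index over the ~21M passages one Arrow record batch at a time, so the full embedding matrix never has to be materialised in memory.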

In [7]:
# Build an exact inner-product index batch by batch and write it to disk.
def batched_train(dim, log_every=1_000_000):
    """Build a flat inner-product FAISS index over the first `dim` embedding
    components, converting one Arrow record batch at a time, and write it to disk.

    Args:
        dim: number of leading embedding components to index.
        log_every: emit a progress log line each time another `log_every`
            passages have been added.

    Returns:
        Path of the written index file.
    """
    # This function always builds a flat IP index, so label the file as such
    # rather than with the global `index_type` (which configures the
    # full-training cell and may say e.g. IVFOPQ).
    index_file = f'results/embed/exact-index/{config_name}-dim{dim}_IP_batched.faiss'
    os.makedirs(os.path.dirname(index_file), exist_ok=True)

    sub_index = faiss.IndexFlatIP(dim)
    faiss_index = faiss.IndexIDMap2(sub_index)

    total = 0
    for batch in tqdm(emb_data.to_batches()):
        batch_data = batch.to_pydict()
        # FAISS ids are int64; don't rely on the platform's default integer type.
        psg_ids = np.asarray(batch_data["id"], dtype=np.int64)

        token_emb = np.array(batch_data["embedding"], dtype=np.float32)
        # Shape (batch_rows, dim); FAISS requires C-contiguous float32 input.
        token_emb = np.ascontiguousarray(token_emb[:, :dim])
        faiss_index.add_with_ids(token_emb, psg_ids)

        prev_total, total = total, total + len(psg_ids)
        # Batch sizes are not multiples of the log interval in general, so log
        # whenever a boundary is crossed rather than testing exact divisibility.
        if total // log_every > prev_total // log_every:
            logger.info(f"indexed {total} passages")

    faiss.write_index(faiss_index, index_file)
    return index_file

if train_batches:
    batched_train(dim=768)

## Full Training (high peak RAM usage, ~120 GB)

In [3]:
# Load every passage id and embedding into memory at once (high peak RAM).
if not train_batches:
    # FAISS expects int64 ids.
    psg_ids = np.asarray(emb_data['id'], dtype=np.int64)
    print(psg_ids.shape) # Passage IDs

    # Takes ~5 min on our system
    token_emb = np.array(emb_data["embedding"])
    token_emb = np.hstack(token_emb)

    # One row per passage: derive the row count from the data rather than
    # hard-coding the corpus size. FAISS also requires float32 vectors
    # (copy=False makes this free when the data is already float32).
    token_emb = token_emb.reshape(psg_ids.shape[0], -1).astype(np.float32, copy=False)
    print(token_emb.shape, token_emb.dtype) # Token Embeddings
else:
    # Deliberately stop "Run All" here: later cells need the full in-memory matrix.
    raise RuntimeError("Insufficient RAM to train on entire data!")

(21015324,)
(21015324, 768) float32


In [6]:
ncell = 10  # Number of IVF cells (coarse centroids) for IVF / IVFOPQ
# NOTE(review): with ~21M vectors each IVF cell holds ~2.1M vectors on average;
# FAISS index guidelines suggest far more cells at this scale -- confirm intent.
dims = [768]  # Embedding dims to train indices over
Ms = [8, 16, 32, 48, 64, 96]  # Number of PQ sub-quantizers for OPQ / IVFOPQ
nbits = 8  # Bits per PQ sub-quantizer code
opq_train_size = 500000  # Training-sample size for the OPQ-only index
seed = 0  # Seeds the OPQ training sample so rebuilt indices are reproducible

# Only these index types depend on the number of PQ sub-quantizers M.
PQ_INDEX_TYPES = ('OPQ', 'IVFOPQ')
ALL_INDEX_TYPES = ('IP', 'L2', 'IVF') + PQ_INDEX_TYPES


def _build_index(index_type, dim, M, data, rng):
    """Create and train (but do not populate) a FAISS index of `index_type`.

    Args:
        index_type: one of ALL_INDEX_TYPES.
        dim: vector dimensionality.
        M: number of PQ sub-quantizers (ignored by IP / L2 / IVF).
        data: C-contiguous float32 training vectors of shape (n, dim).
        rng: numpy Generator used to draw the OPQ training sample.

    Returns:
        (faiss_index, index_file) -- the trained index and the path to write it to.

    Raises:
        ValueError: if `index_type` is not recognised.
    """
    if index_type == 'IP':
        # Exact (flat) inner-product index
        index_file = f'results/embed/IP/{config_name}-dim{dim}_IP.faiss'
        sub_index = faiss.IndexFlatIP(dim)
        faiss_index = faiss.IndexIDMap2(sub_index)

    elif index_type == 'L2':
        # Exact (flat) L2 index
        index_file = f'results/embed/L2/{config_name}-dim{dim}_L2.faiss'
        sub_index = faiss.IndexFlatL2(dim)
        faiss_index = faiss.IndexIDMap2(sub_index)

    elif index_type == 'IVF':
        index_file = f'results/embed/IVF/{config_name}-dim{dim}_IVF_ncell{ncell}.faiss'
        quantizer = faiss.IndexFlatL2(dim)
        faiss_index = faiss.IndexIVFFlat(quantizer, dim, ncell)
        faiss_index.train(data)

    elif index_type == 'OPQ':
        index_file = f'results/embed/OPQ/{config_name}-dim{dim}_OPQ_M{M}_nbits{nbits}.faiss'
        # Train on a random subsample; never ask for more rows than exist.
        n_train = min(opq_train_size, data.shape[0])
        train_rows = rng.choice(data.shape[0], n_train, replace=False)
        sub_index = faiss.index_factory(dim, f"OPQ{M},PQ{M}x{nbits}")
        faiss_index = faiss.IndexIDMap2(sub_index)
        faiss_index.train(data[train_rows])

    elif index_type == 'IVFOPQ':
        index_file = f'results/embed/IVFOPQ/{config_name}-dim{dim}_IVFOPQ_cell{ncell}_M{M}_nbits{nbits}.faiss'
        sub_index = faiss.index_factory(dim, f"OPQ{M},IVF{ncell},PQ{M}x{nbits}")
        faiss_index = faiss.IndexIDMap2(sub_index)
        faiss_index.train(data)

    else:
        raise ValueError(f"Unknown index_type {index_type!r}; expected one of {ALL_INDEX_TYPES}")

    return faiss_index, index_file


# Fail before the expensive slicing/normalisation if the config is wrong.
if index_type not in ALL_INDEX_TYPES:
    raise ValueError(f"Unknown index_type {index_type!r}; expected one of {ALL_INDEX_TYPES}")

rng = np.random.default_rng(seed)
for dim in dims:
    # Slicing and normalisation depend only on dim, so do them once per dim.
    # When dim equals the full embedding width the slice is already contiguous,
    # np.ascontiguousarray returns a view, and normalize_L2 rescales token_emb
    # in place (re-running is harmless: L2 normalisation is idempotent).
    token_emb_sliced = np.ascontiguousarray(token_emb[:, :dim])
    faiss.normalize_L2(token_emb_sliced)

    # IP / L2 / IVF ignore M: build them once per dim instead of once per M.
    m_values = Ms if index_type in PQ_INDEX_TYPES else [None]
    for M in m_values:
        if M is not None and (M > dim or dim % M != 0):
            print("Skipping (d,M) : (%d, %d)" % (dim, M))
            continue

        print("Adding DB: ", token_emb_sliced.shape)
        print(f'Generating {index_type} index on config: {config_name}')

        tic = time.perf_counter()
        faiss_index, index_file = _build_index(index_type, dim, M, token_emb_sliced, rng)
        faiss_index.add_with_ids(token_emb_sliced, psg_ids)
        os.makedirs(os.path.dirname(index_file), exist_ok=True)
        faiss.write_index(faiss_index, index_file)
        toc = time.perf_counter()

        print("Generated ", index_file)
        print("Time to build index with d=%d : %f" % (dim, toc - tic))
        # Release this index before the next (potentially large) one is built.
        del faiss_index

Generating dpr-nq-d768-wiki-dim768_IVFOPQ_cell10_M8.faiss
Adding DB:  (21015324, 768)
Time to build index with d=768 : 1723.532558
Generating dpr-nq-d768-wiki-dim768_IVFOPQ_cell10_M16.faiss
Adding DB:  (21015324, 768)
Time to build index with d=768 : 1926.746812
Generating dpr-nq-d768-wiki-dim768_IVFOPQ_cell10_M32.faiss
Adding DB:  (21015324, 768)
Time to build index with d=768 : 2302.668599
Generating dpr-nq-d768-wiki-dim768_IVFOPQ_cell10_M48.faiss
Adding DB:  (21015324, 768)
Time to build index with d=768 : 2746.178078
Generating dpr-nq-d768-wiki-dim768_IVFOPQ_cell10_M64.faiss
Adding DB:  (21015324, 768)
Time to build index with d=768 : 2214.350993
Generating dpr-nq-d768-wiki-dim768_IVFOPQ_cell10_M96.faiss
Adding DB:  (21015324, 768)
Time to build index with d=768 : 2294.547853


# Search (restart the kernel first to free the memory held by the in-memory embeddings)

In [2]:
%%bash
# Evaluate the retriever against each index built above.
split=test
ds=nq

# Change these -- they must match the values used when the indices were built.
# The index_file pattern below follows the IVFOPQ naming of the build cell.
index_type=IVFOPQ
config_name=dpr-nq-d768_384_192_96_48-wiki
ncell=10
nbits=8
# NOTE(review): the eval log output below shows the checkpoint directory without
# the "-wiki" suffix (ckpt/dpr-nq-d768_384_192_96_48); confirm this path.
model_ckpt=ckpt/${config_name%-wiki}

mkdir -p results/json results/logs

for M in 16 32
do
    for d in 768
    do
        # Tag outputs with the index type and M so runs over different M values
        # do not overwrite each other's predictions and logs.
        run_tag=${config_name}-${ds}-${split}-dim${d}-${index_type}-M${M}
        python rtr/cli/eval_retriever.py \
        --passage_db_file data/psgs-w100.lmdb \
        --model_ckpt ${model_ckpt} \
        --index_file results/embed/${index_type}/${config_name}-dim${d}_${index_type}_cell${ncell}_M${M}_nbits${nbits}.faiss \
        --dataset_file qas-data/${ds}-${split}.csv \
        --save_file results/json/reader-${run_tag}.jsonl \
        --batch_size 512 \
        --max_question_len 200 \
        --embedding_size ${d} \
        --metrics_file results/metrics.json \
        --binary False \
        2>&1 | tee results/logs/eval-${run_tag}.log
        echo -e "Finished Processing!\n"
    done
done

2023-05-21 20:04:27.733 | INFO     | __main__:batch_eval_dataset:94 - init Retriever from model_ckpt=ckpt/dpr-nq-d768_384_192_96_48
2023-05-21 20:04:35.608 | INFO     | __main__:batch_eval_dataset:100 - loading index_file=results/embed/IVFOPQ/dpr-nq-d768-wiki-dim768_IVFOPQ_cell10_M16.faiss...
2023-05-21 20:04:41.547 | INFO     | __main__:batch_eval_dataset:107 - loading passage_db_file=data/psgs-w100.lmdb...
2023-05-21 20:04:41.758 | INFO     | __main__:batch_eval_dataset:114 - loading QA pairs from qas-data/nq-test.csv
2023-05-21 20:04:41.803 | INFO     | __main__:batch_eval_dataset:119 - computing query embeddings...
2023-05-21 20:04:41.804 | INFO     | __main__:batch_eval_dataset:121 - begin searching max(top_k)=200 passage for 3610 question...
search 1668.9 queries/s, checking answers: 100%|██████████| 8/8 [10:52<00:00, 81.59s/it]
2023-05-21 20:15:34.546 | INFO     | __main__:batch_eval_dataset:154 - #total examples: 3610
2023-05-21 20:15:34.567 | INFO     | __main__:batch_eval_dat