In [2]:
import faiss 
import numpy as np
from faiss.loader import swig_ptr
from tqdm.notebook import tqdm
import pandas as pd
import os
import time
from faiss.contrib.inspect_tools import get_invlist
#from run_inference import *

### Description 

This script contains the code to extract the data structures required by EMVB from a ```faiss``` index.

It assumes you already trained a ```faiss``` index on your collection and that you generated the query embeddings as well.
To do so, you need to 
1. encode the collection using a COLBERT (or JMPQ) model.
2. encode the queries as above.
3. build a faiss ivfpq index on the collection. 

For points 1 and 2, you can refer to the original GitHub repositories, [ColBERT](https://github.com/stanford-futuredata/ColBERT) and [JMPQ](https://github.com/Suffoquer-fang/JMPQ). 

For point 3, you can use something like the following: 

In [9]:
# Template only: the code is wrapped in a string literal, so running this cell
# does not execute it (the cell just echoes the string).
# Fill in the empty paths and the `ncentroids` / `m` placeholders before using it.
# NOTE(review): the template trains the index but never calls index.add(); the
# extraction cells further down read index_jmpq.ntotal and the inverted lists,
# so the collection presumably has to be added before write_index -- confirm.
'''
d = 128
collection_path = "" # If the collection is too big, you may need to sample it. Make sure to keep about 10% of the original data
training_set = np.load(collection_path)


ncentroids = # specify the number of centroids, somthing like 2**18
m = # specify the number of partitions 
nbits = 8
quantizer = faiss.IndexFlatL2(d)

index = faiss.IndexIVFPQ(quantizer, d, ncentroids, m, nbits)
index.train(training_set)

save_index_path = ""
faiss.write_index(index, save_index_path)
'''

'\nd = 128\ncollection_path = "" # If the collection is too big, you may need to sample it. Make sure to keep about 10% of the original data\ntraining_set = np.load(collection_path)\n\n\nncentroids = # specify the number of centroids, somthing like 2**18\nm = # specify the number of partitions \nnbits = 8\nquantizer = faiss.IndexFlatL2(d)\n\nindex = faiss.IndexIVFPQ(quantizer, d, ncentroids, m, nbits)\nindex.train(training_set)\n\nsave_index_path = ""\nfaiss.write_index(index, save_index_path)\n'

### Extraction

In [3]:
# Path to the faiss index file; it is loaded further down with faiss.read_index.
# Empty by default: set it before running the rest of the notebook.

index_path = ""

In [4]:
# Directory where the generated files are saved (centroids_to_pids.txt and the
# .npy arrays written by the cells below).
# Empty by default: set it before running the rest of the notebook.

dest_dir = ""

In [None]:
# Create the output directory from Python instead of shelling out to `mkdir`:
# this also creates missing parent directories, copes with paths that contain
# spaces, works on any platform, and does not fail when the directory already
# exists, so the cell can be re-run safely.
os.makedirs(dest_dir, exist_ok=True)

### Generate index decomposition

In [4]:
# Load the faiss index from index_path. The cells below read its coarse
# quantizer, its inverted lists and its product quantizer through index_jmpq.
index_jmpq = faiss.read_index(index_path)

In [5]:
# Coarse centroids, read back from the index quantizer (one row per inverted list).
centroids = index_jmpq.quantizer.reconstruct_n(0, index_jmpq.nlist)

# Per-embedding buffers, filled while walking the inverted lists:
#   residuals   -> PQ code of each embedding (pq.M bytes per row)
#   all_indices -> id of the inverted list each embedding is stored in
residuals = np.zeros((index_jmpq.ntotal, index_jmpq.pq.M), dtype=np.uint8)
all_indices = np.zeros((index_jmpq.ntotal,), dtype=np.uint64)

# One slot per centroid; each slot is filled later with the document ids of that list.
centroids_to_pids = [None for _ in range(len(centroids))]

In [6]:
# Number of embeddings per document, in document-id order.
doclensArray = np.load("./external/msmarco/doclens_msmarco.npy")
n_docs = len(doclensArray)

# Total number of embeddings in the collection, taken from the index itself.
# (This assignment previously had no right-hand side, which is a SyntaxError.)
tot_embedding = index_jmpq.ntotal

# The document lengths must account for every embedding in the index; otherwise
# the embedding -> document mapping built below would be silently wrong.
if int(doclensArray.sum()) != tot_embedding:
    raise ValueError(
        f"doclens sum to {int(doclensArray.sum())} embeddings, "
        f"but the index holds {tot_embedding}"
    )

# emb2pid[e] = id of the document that embedding e belongs to:
# document i is repeated doclensArray[i] times, in order.
emb2pid = np.repeat(np.arange(n_docs, dtype=np.int64), doclensArray.astype(np.int64))

# doc_offsets[d] = position of the first embedding of document d,
# i.e. the exclusive prefix sum of the document lengths.
doc_offsets = np.zeros(n_docs, dtype=np.int64)
doc_offsets[1:] = np.cumsum(doclensArray[:-1], dtype=np.int64)
    

In [7]:
# Walk every inverted list and scatter its content into the per-embedding buffers.
for list_no in tqdm(range(index_jmpq.nlist)):
    emb_ids, pq_codes = get_invlist(index_jmpq.invlists, list_no)
    # Document id of every embedding stored in this list
    # (one entry per embedding, so a document id may appear several times).
    centroids_to_pids[list_no] = emb2pid[emb_ids]
    # Record which list each embedding sits in, together with its PQ code.
    all_indices[emb_ids] = list_no
    residuals[emb_ids] = pq_codes

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 262144/262144 [01:21<00:00, 3217.93it/s]


In [8]:
# Dump the centroid -> document-id lists as text: one line per centroid,
# every id followed by a single space (so non-empty lines end with a space).
with open(os.path.join(dest_dir, "centroids_to_pids.txt"), "w") as out:
    for pid_list in tqdm(centroids_to_pids):
        # Build the whole line first, then write it with a single call.
        out.write("".join(f"{pid} " for pid in pid_list) + "\n")

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 262144/262144 [03:38<00:00, 1197.27it/s]


In [9]:
# Save the per-embedding PQ codes (uint8, one row per embedding) as
# residuals.npy inside dest_dir.
np.save(os.path.join(dest_dir, "residuals.npy"), residuals)

In [10]:
# Save the coarse centroids reconstructed from the index quantizer as
# centroids.npy inside dest_dir.
np.save(os.path.join(dest_dir, "centroids.npy"), centroids)

In [11]:
# Save, for every embedding, the id of the inverted list it was found in
# (uint64) as index_assignment.npy inside dest_dir.
np.save(os.path.join(dest_dir, "index_assignment.npy"), all_indices)

In [12]:
# Save the product-quantizer centroids of the index as pq_centroids.npy inside dest_dir.
# faiss.vector_to_array converts index_jmpq.pq.centroids into a numpy array.
# NOTE(review): presumably a flat 1-D array that the consumer reshapes -- confirm.
pq_centroids = faiss.vector_to_array(index_jmpq.pq.centroids)
np.save(os.path.join(dest_dir, "pq_centroids.npy"), pq_centroids)

##### Query embeddings

As a final step, you need to copy the query embeddings into ```dest_dir```, for example by using bash. 

- If you are running a simple PQ algorithm, then you just need to copy the query_embeddings as they are.
- If you are running OPQ, you need to rotate the queries before copying. The rotation time is not included in the time measurements of EMVB as it is negligible compared to the search time. 
- If you are running JMPQ, you need to re-encode the queries, since JMPQ also fine-tunes the query encoder.