In [1]:
import faiss
import numpy as np
import time
import csv
from os import path, makedirs

import multiprocessing
from multiprocessing.dummy import Pool as ThreadPool
from functools import partial

from faiss.contrib.ivf_tools import add_preassigned, search_preassigned

## AdANNS-IVF

### Notation
1. $D$ = Embedding Dimensionality for IVF construction and search
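2. $M$ = number of OPQ sub-quantizers. Faiss requires $M$ to divide $D$, i.e. $D \bmod M = 0$.
3. For AdANNS, $D$ is decoupled into $D_{construct}$ (the dimensionality used to learn the IVF centroids and assign database vectors to cells) and $D_{search}$ (the dimensionality used to learn the OPQ/PQ codebooks and to search within the probed cells).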

### Miscellaneous Notes
1. Rigid representations (RR) are embedded through independently trained "fixed feature" (FF) encoders. RR and FF are thus used interchangeably in documentation and code and are essentially equivalent.
2. In this notebook, the AdANNS-IVF search index applies OPQ before PQ encoding by default for cheap distance computation, but OPQ is <u>optional</u>.
3. AdANNS-IVF is adapted from this [Faiss Case Study](https://gist.github.com/mdouze/8c5ab227c0f7d9d7c15cf92a391dcbe5#file-demo_independent_ivf_dimension-ipynb)
4. Optimized AdANNS-IVF (with Faiss) requires $D_{construct}\geq D_{search}$. This is because the centroids learnt with $D_{construct}$ are sliced to their first $D_{search}$ dimensions, and the PQ codebooks are then learnt with $D_{search}$. This slicing is possible because the embeddings are Matryoshka Representations (MRs).

In [9]:
D = 2048 # Max d for ResNet50
n_cell = 1024 # number of IVF cells, default=1024 for ImageNet-1K

# Every root must end with '/': file names are appended by plain string
# concatenation (e.g. embeddings_root + db_npy, f'{embeddings_root}index_files/...').
embeddings_root = 'path/to/embeddings/' # load embeddings
adanns_root = 'path/to/adanns/indices/' # store adanns indices
rigid_root = 'path/to/rigid/indices/' # store rigid indices
config = 'rr' # mrl, rr

# config_load is the model tag embedded in the embedding / label file names.
if config == 'mrl':
    config_load = 'mrl1_e0_ff2048'
elif config == 'rr':
    config_load = 'mrl0_e0_ff2048'
else:
    raise ValueError(f"Unsupported config {config}!")

use_mrl = config.upper() # MRL, RR; prefix of the rigid index file names

db_npy = '1K_train_' + config_load + '-X.npy'
query_npy = '1K_val_' + config_load + '-X.npy'

In [None]:
xb = np.load(embeddings_root + db_npy)
xq = np.load(embeddings_root + query_npy)

query_labels = np.load(embeddings_root + "1K_val_" + config_load + "-y.npy")
db_labels = np.load(embeddings_root + "1K_train_" + config_load + "-y.npy")

# NaNs silently corrupt k-means / PQ training and every distance computation,
# so validate the queries as well as the database.
assert np.count_nonzero(np.isnan(xb)) == 0, "database embeddings contain NaN"
assert np.count_nonzero(np.isnan(xq)) == 0, "query embeddings contain NaN"
# Top-1 accuracy is computed as db_labels[Ind[:, 0]] == query_labels, so the
# label arrays must be row-aligned with the embeddings they describe.
assert db_labels.shape[0] == xb.shape[0], "db labels / embeddings length mismatch"
assert query_labels.shape[0] == xq.shape[0], "query labels / embeddings length mismatch"

print("loaded DB %s : %s" % (db_npy, xb.shape))
print("loaded queries %s : %s" % (query_npy, xq.shape))

## RR2048 OPQ Dim Reduction Baseline

In [11]:
# Seeded so the PCA/SVD training subset (and thus every downstream number) is
# reproducible across kernel restarts.
SUBSAMPLE_SEED = 0
n_subsample = min(100000, xb.shape[0])  # never ask for more rows than exist
rng = np.random.default_rng(SUBSAMPLE_SEED)
db_subsampled = xb[rng.choice(xb.shape[0], n_subsample, replace=False)]
print(db_subsampled.shape)
dim_reduce = 128

(100000, 2048)


### SVD (PCA via `faiss.PCAMatrix`) dim reduction + OPQ

In [12]:
def get_SVD_mat(db_subsampled, low_dim):
    """Train and return a ``faiss.PCAMatrix`` projecting to ``low_dim`` dimensions.

    The input dimensionality is taken from ``db_subsampled.shape[1]``; the matrix
    is trained on ``db_subsampled`` and can then be applied to any array of that
    width (database or queries).
    """
    in_dim = db_subsampled.shape[1]
    projection = faiss.PCAMatrix(in_dim, low_dim)
    projection.train(db_subsampled)
    assert projection.is_trained
    return projection

svd_mat = get_SVD_mat(db_subsampled, dim_reduce)

# Project database and queries through the same trained matrix and report shapes.
database_svd_lowdim = svd_mat.apply(xb)
print("SVD projected Database: ", database_svd_lowdim.shape)
query_svd_lowdim = svd_mat.apply(xq)
print("SVD projected Queries: ", query_svd_lowdim.shape)

# faiss.normalize_L2 works in place, so each projected array is L2-normalized directly.
for projected in (database_svd_lowdim, query_svd_lowdim):
    faiss.normalize_L2(projected)

SVD projected Database:  (1281167, 128)
SVD projected Queries:  (50000, 128)


In [13]:
for M in [128]:
    index_name = f'{use_mrl}_D2048_SVD{dim_reduce}_OPQ{M}.faiss'
    index_path = f'{rigid_root}/SVD_dimreduce/{index_name}'
    if not path.exists(index_path):
        print(f'Building {index_name}')
        # faiss.write_index does not create missing parent directories; fail fast
        # here rather than after hours of training.
        makedirs(path.dirname(index_path), exist_ok=True)
        cpu_index = faiss.index_factory(dim_reduce, f'OPQ{M},PQ{M}')
        start = time.time()
        cpu_index.train(database_svd_lowdim)
        cpu_index.add(database_svd_lowdim)
        print("Train+add time: ", time.time() - start)
        faiss.write_index(cpu_index, index_path)

        # Top-1 accuracy: label of the nearest database vector vs. the query label.
        top1 = [xb.shape[1], dim_reduce, M]
        _, Ind = cpu_index.search(query_svd_lowdim, 100)
        top1.append((np.sum(db_labels[Ind[:, 0]] == query_labels)) / query_labels.shape[0])
        print(top1)
    else:
        print(f'Skipping build, index exists: {index_name}')

Building FF_D2048_SVD128_OPQ128.faiss
Train+add time:  13411.728565454483
[2048, 128, 128, 0.69224]


In [None]:
# Evaluate the SVD + OPQ index; the file name must match the one the build loop
# writes ('..._SVD{dim_reduce}_OPQ{M}.faiss').
for M in [128]:
    top1 = [xb.shape[1], dim_reduce, M]
    svd_opq_index = faiss.read_index(f'{rigid_root}/SVD_dimreduce/{use_mrl}_D2048_SVD{dim_reduce}_OPQ{M}.faiss')
    _, Ind = svd_opq_index.search(query_svd_lowdim, 100)
    top1.append((np.sum(db_labels[Ind[:, 0]] == query_labels)) / query_labels.shape[0])
    print(top1)

[2048, 128, 16, 0.69088]


## Rigid-IVF + OPQ

In [19]:
# Construct Rigid Index

for M in [16]:
    for D in [2048]:
        if M > D:
            continue

        index_name = f'{use_mrl}_D{D}+IVF{n_cell},OPQ{M}.faiss'
        index_path = f'{rigid_root}/IVFOPQ/{index_name}'

        if path.exists(index_path):
            print(f'Skipping build, index exists: {index_name}')
            continue

        # Only materialise the normalized float32 copy of the whole database when an
        # index actually has to be built.
        database = np.ascontiguousarray(xb[:, :D], dtype=np.float32)
        faiss.normalize_L2(database)

        # faiss.write_index does not create missing parent directories.
        makedirs(path.dirname(index_path), exist_ok=True)

        print(f'Building {index_name}')
        start = time.time()

        index = faiss.index_factory(int(D), f'IVF{n_cell},PQ{M}')

        # Rotate the database with the pretrained OPQ matrix (first element of the
        # pretrained chain) so IVF-PQ is trained and filled in the OPQ-rotated space.
        opq_index_pretrained = faiss.read_index(f'{embeddings_root}index_files/{config}/opq/1K_opq_{M}m_d{D}_nbits8.index')
        opq = opq_index_pretrained.chain.at(0)

        db = opq.apply(database)

        index.train(db)
        index.add(db)

        print("Time: ", time.time() - start)
        faiss.write_index(index, index_path)
        print(f'Created IVF{n_cell},OPQ{M} index with D={D}')

Skipping build, index exists: FF_D2048+IVF1024,OPQ16.faiss


In [23]:
# Search Rigid Index

print('[n_cell, D, M, top1]')
for D in [2048]:
    # Queries are truncated to the first D dims and L2-normalized in place.
    queryset = np.ascontiguousarray(xq[:, :D], dtype=np.float32)
    faiss.normalize_L2(queryset)

    for M in (m for m in [8, 16, 32, 64] if m <= D):
        row_prefix = [n_cell, D, M]
        top1 = list(row_prefix)
        times = list(row_prefix)

        index_path = f'{rigid_root}/IVFOPQ/{use_mrl}_D{D}+IVF{n_cell},OPQ{M}.faiss'
        opq_path = f'{embeddings_root}index_files/{config}/opq/1K_opq_{M}m_d{D}_nbits8.index'

        index = faiss.read_index(index_path)

        # The index holds OPQ-rotated vectors, so rotate the queries with the same matrix.
        opq_index_pretrained = faiss.read_index(opq_path)
        opq = opq_index_pretrained.chain.at(0)
        q = opq.apply(queryset)

        for nprobe in [1]:
            start = time.time()
            faiss.extract_index_ivf(index).nprobe = nprobe
            Dist, Ind = index.search(q, 100)

            nearest_labels = db_labels[Ind[:, 0]]
            top1.append((np.sum(nearest_labels == query_labels)) / query_labels.shape[0])
            times.append(time.time() - start)

        print(top1)

[n_cell, D, M, top1]
[1024, 2048, 8, 0.64966]
[1024, 2048, 16, 0.6663]
[1024, 2048, 32, 0.67724]
[1024, 2048, 64, 0.68588]


## AdANNS-IVF + OPQ

In [45]:
def create_adanns_indices(D_search, D_construct, M, n_cell):
    """Build (or reuse from disk) the decoupled AdANNS-IVF quantizer / index pair.

    Two artifacts are written under ``adanns_root``. Each step is skipped (with an
    "Index exists" message) when its file is already present.

    1. Construct quantizer ``MRL_D{D_construct}+IVF{n_cell},PQ{M}_construct_quantizer.faiss``:
       the coarse quantizer of an ``IVF{n_cell},Flat`` index trained on the first
       ``D_construct`` dims of ``xb`` (L2-normalized).
    2. Search index ``MRL_Dsearch{D_search}_Dconstruct{D_construct}+IVF{n_cell},OPQ{M}_search_quantizer.faiss``:
       the IVF-PQ part of an ``OPQ{M},IVF{n_cell},PQ{M}`` index over the first
       ``D_search`` dims. Its coarse centroids are the construct centroids truncated
       to ``D_search`` dims. Database vectors are placed in inverted lists using the
       construct quantizer's assignments (``add_preassigned``).

    Args:
        D_search: dimensionality used for PQ training/encoding. It must be
            <= D_construct because the centroids are sliced; the calling loop
            skips other combinations.
        D_construct: dimensionality used to learn the IVF centroids and to assign
            database vectors to cells.
        M: number of PQ sub-quantizers. It also selects the pretrained OPQ file
            ``1K_opq_{M}m_d{D_search}_nbits8.index``.
        n_cell: number of IVF cells.

    Reads notebook globals ``xb``, ``adanns_root``, ``embeddings_root`` and ``config``.
    The ``MRL_`` file-name prefix is fixed regardless of ``config``.

    NOTE(review): ``opq.apply(...)`` returns a transformed copy; the rigid cells use
    it as ``q = opq.apply(queryset)``. Below, its results are discarded. Neither the
    centroids nor the database are actually OPQ-rotated, so the saved index is plain
    IVF-PQ. ``search_adanns_indices`` does not rotate its queries either, so the two
    are currently consistent. Enabling OPQ requires changing both functions together.
    """
    # The OPQ stage of index_search is never trained; only its IVF-PQ part
    # (extracted below) is trained, filled and saved.
    index_search = faiss.index_factory(D_search, f'OPQ{M},IVF{n_cell},PQ{M}')
    index_construct = faiss.index_factory(D_construct, f'IVF{n_cell},Flat')
    
    # Database truncated to D_construct dims and L2-normalized in place. It is used to
    # train the construct quantizer and for the coarse assignment of every vector.
    database = np.ascontiguousarray(xb[:,:D_construct], dtype=np.float32)
    faiss.normalize_L2(database)

    # train the full-dimensional "construct" coarse quantizer. IVF centroid assignments are learnt with D_construct
    if not path.exists(adanns_root+f'MRL_D{D_construct}+IVF{n_cell},PQ{M}_construct_quantizer.faiss'):
        index_construct.train(database)
        quantizer_construct = index_construct.quantizer
        faiss.write_index(quantizer_construct, adanns_root+f'MRL_D{D_construct}+IVF{n_cell},PQ{M}_construct_quantizer.faiss')
    else:
        print("Index exists: ", adanns_root+f'MRL_D{D_construct}+IVF{n_cell},PQ{M}_construct_quantizer.faiss')

    # prepare the "search" coarse quantizer. OPQ codebooks are learnt on D_search
    if not path.exists(adanns_root+f'MRL_Dsearch{D_search}_Dconstruct{D_construct}+IVF{n_cell},OPQ{M}_search_quantizer.faiss'):
        # Always reload the construct quantizer from disk, whether it was trained above or already existed.
        quantizer_construct = faiss.read_index(adanns_root+f'MRL_D{D_construct}+IVF{n_cell},PQ{M}_construct_quantizer.faiss')
        # Unlike `database`, this D_search slice is not normalized before the OPQ step.
        database_search = np.ascontiguousarray(xb[:, :D_search], dtype=np.float32)
        # Construct centroids truncated to their first D_search dims (hence D_search <= D_construct).
        centroids_search = np.ascontiguousarray(quantizer_construct.reconstruct_n(0, quantizer_construct.ntotal)[:, :D_search], dtype=np.float32)
        
        # Apply OPQ to search DB and centroids
        opq_index_pretrained = faiss.read_index(f'{embeddings_root}index_files/{config}/opq/1K_opq_{M}m_d{D_search}_nbits8.index')
        print(f'Applying OPQ: 1K_opq_{M}m_d{D_search}')
        opq = opq_index_pretrained.chain.at(0)
        # NOTE(review): the return values of these two calls are discarded, so they have
        # no effect on centroids_search / database_search (see docstring).
        opq.apply(centroids_search)
        opq.apply(database_search)
        faiss.normalize_L2(database_search)
        
        # Seed the search index's coarse quantizer with the sliced construct centroids.
        index_ivf_search = faiss.downcast_index(faiss.extract_index_ivf(index_search))
        index_ivf_search.quantizer.add(centroids_search)

        # The quantizer already holds the centroids; presumably this call only learns
        # the PQ codebooks. TODO confirm against faiss IndexIVF::train.
        index_ivf_search.train(database_search)
        # Mark the OPQ wrapper trained; its OPQ matrix is never trained and the wrapper is not saved.
        index_search.is_trained = True

        # coarse quantization with the construct quantizer
        _, Ic = quantizer_construct.search(database, 1) # each database vector assigned to one of num_cell centroids
        # add operation 
        add_preassigned(index_ivf_search, database_search, Ic.ravel())

        faiss.write_index(index_ivf_search, adanns_root+f'MRL_Dsearch{D_search}_Dconstruct{D_construct}+IVF{n_cell},OPQ{M}_search_quantizer.faiss')
    else:
        print("Index exists: ", adanns_root+f'MRL_Dsearch{D_search}_Dconstruct{D_construct}+IVF{n_cell},OPQ{M}_search_quantizer.faiss')
    
    print(f'Initialized construct quantizer D{D_construct}, search quantizer D{D_search}, M{M}, ncell{n_cell}')

In [46]:
# Enumerate every (D_search, D_construct, M) combination to build. Centroids are
# sliced from D_construct down to D_search, so only M <= D_search <= D_construct is valid.
adanns_build_grid = [
    (D_search, D_construct, M)
    for D_construct in [64, 128, 256, 512, 1024, 2048]
    for D_search in [2048]
    for M in [64]
]
for D_search, D_construct, M in adanns_build_grid:
    if M <= D_search <= D_construct:
        create_adanns_indices(D_search, D_construct, M, n_cell=1024)
    else:
        print("Skipping (M, d_search, d_construct): ", (M, D_search, D_construct))

Skipping (M, d_small, d_big):  (64, 2048, 64)
Skipping (M, d_small, d_big):  (64, 2048, 128)
Skipping (M, d_small, d_big):  (64, 2048, 256)
Skipping (M, d_small, d_big):  (64, 2048, 512)
Skipping (M, d_small, d_big):  (64, 2048, 1024)
Index exists:  case_study_decoupled/MRL_D2048+IVF1024,PQ64_big_quantizer.faiss
Applying OPQ: 1K_opq_64m_d2048
Initialized big quantizer D2048, small quantizer D2048, M64, ncell1024


In [None]:
# Preassigned Search using multiple cores

USE_MULTITHREAD_SEARCH = True
num_cores = multiprocessing.cpu_count()
thread_batch_size = 1000

# Helper function to split search on multiple cores
def multisearch_preassigned(index, queryset, Ic, batch_iter):
    """Run ``search_preassigned`` on one contiguous batch of queries.

    Batch ``batch_iter`` covers rows ``[batch_iter * thread_batch_size,
    (batch_iter + 1) * thread_batch_size)`` of both ``queryset`` and the coarse
    assignments ``Ic``; slicing truncates the final batch if it is short.
    Returns only the id array (the second result) for a shortlist of 100;
    distances are dropped and no coarse distances are passed (``None``).
    """
    rows = slice(thread_batch_size * batch_iter, thread_batch_size * (batch_iter + 1))
    shortlist_len = 100
    _distances, ids = search_preassigned(index, queryset[rows], shortlist_len, Ic[rows], None)
    return ids

In [49]:
def search_adanns_indices(D_search, D_construct, n_cell, nprobes=(1,)):
    """Evaluate top-1 accuracy and latency of the AdANNS-IVF indices from ``create_adanns_indices``.

    The probed IVF cells are chosen with the first ``D_construct`` query dims against
    the construct quantizer. The PQ shortlist search then runs on the first
    ``D_search`` query dims with those cells preassigned.

    For each M, one row ``[n_cell, D_construct, D_search, M, acc@nprobe...]`` is
    appended to 'adanns-faiss-top1-opq.csv' and the matching wall-clock times are
    appended to 'adanns-faiss-timing-opq.csv'. The accuracy row is printed.

    Args:
        D_search: query dimensionality for the PQ search (must be <= D_construct).
        D_construct: query dimensionality for coarse quantization.
        n_cell: number of IVF cells the indices were built with.
        nprobes: iterable of nprobe values to evaluate (one CSV column each).

    Reads notebook globals: xq, query_labels, db_labels, adanns_root,
    USE_MULTITHREAD_SEARCH, num_cores, thread_batch_size.
    """
    queryset_construct = np.ascontiguousarray(xq[:, :D_construct], dtype=np.float32)
    faiss.normalize_L2(queryset_construct)

    queryset_search = np.ascontiguousarray(xq[:, :D_search], dtype=np.float32)
    faiss.normalize_L2(queryset_search)
    n_queries = queryset_search.shape[0]

    for M in [64]:
        top1 = [n_cell, D_construct, D_search, M]
        times = [n_cell, D_construct, D_search, M]
        if M > D_search or D_search > D_construct:
            continue

        # These paths must match exactly the files written by create_adanns_indices.
        quantizer_construct = faiss.read_index(adanns_root + f'MRL_D{D_construct}+IVF{n_cell},PQ{M}_construct_quantizer.faiss')
        index_ivf_search = faiss.read_index(adanns_root + f'MRL_Dsearch{D_search}_Dconstruct{D_construct}+IVF{n_cell},OPQ{M}_search_quantizer.faiss')

        # Disable precomputed tables: cell assignments come from the construct
        # quantizer (no coarse distances are passed to search_preassigned), so
        # tables tied to the search index's own coarse quantizer would be out of sync.
        index_ivf_search.use_precomputed_table = -1
        index_ivf_search.precompute_table()

        # NOTE(review): queries are not OPQ-rotated. This matches create_adanns_indices,
        # whose opq.apply results are discarded when the index is built.
        for nprobe in nprobes:
            start = time.time()

            # coarse quantization with the D_construct queries; Ic: (n_queries, nprobe)
            _, Ic = quantizer_construct.search(queryset_construct, nprobe)

            index_ivf_search.nprobe = nprobe

            if USE_MULTITHREAD_SEARCH:
                # Ceil-divide so a trailing partial batch is not silently dropped.
                n_batches = -(-n_queries // thread_batch_size)
                # Bind the fixed arguments positionally. If they were bound by keyword,
                # pool.map's positional batch index would also land on `index`, raising
                # "got multiple values for argument 'index'".
                search_batch = partial(multisearch_preassigned, index_ivf_search, queryset_search, Ic)
                pool = ThreadPool(num_cores)
                try:
                    batch_ids = pool.map(search_batch, range(n_batches))
                finally:
                    pool.close()
                    pool.join()
                # pool.map returns one id array per batch, in order; stitch them back together.
                I = np.vstack(batch_ids)
            else:
                _, I = search_preassigned(index_ivf_search, queryset_search, 100, Ic, None) # I: (n_queries, 100)

            top1.append((np.sum(db_labels[I[:, 0]] == query_labels)) / query_labels.shape[0])
            times.append(time.time() - start)

        if len(top1) > 4:  # only write rows for which at least one nprobe was evaluated
            with open('adanns-faiss-top1-opq.csv', 'a', encoding='UTF8', newline='') as f:
                writer = csv.writer(f)
                writer.writerow(top1)
            with open('adanns-faiss-timing-opq.csv', 'a', encoding='UTF8', newline='') as f:
                writer = csv.writer(f)
                writer.writerow(times)
            print(top1)

## Metric Computation

In [50]:
# nprobe values to evaluate; the CSV header gets exactly one column per value so
# it always lines up with the rows written by search_adanns_indices.
NPROBES = [1]
header = ["n_cell", "D_construct", "D_search", "M"] + [f"{nprobe}probe" for nprobe in NPROBES]
print(header)

# Truncate both result files and write their shared header.
for csv_name in ('adanns-faiss-top1-opq.csv', 'adanns-faiss-timing-opq.csv'):
    with open(csv_name, 'w', encoding='UTF8', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(header)

for D_construct in [64, 128, 256, 512, 1024, 2048]:
    for D_search in [64, 128, 256, 512, 1024, 2048]:
        search_adanns_indices(D_search, D_construct, n_cell=1024, nprobes=NPROBES)

['n_cell', 'D_big', 'D_small', 'M', '1probe', '4probe', '8probe']
[1024, 64, 64, 64, 0.6942]
[1024, 128, 64, 64, 0.69422]
[1024, 128, 128, 64, 0.69584]
[1024, 256, 64, 64, 0.69334]
[1024, 256, 128, 64, 0.69604]
[1024, 256, 256, 64, 0.69632]
[1024, 512, 64, 64, 0.69418]
[1024, 512, 128, 64, 0.69676]
[1024, 512, 256, 64, 0.69568]
[1024, 512, 512, 64, 0.6969]
[1024, 1024, 64, 64, 0.69576]
[1024, 1024, 128, 64, 0.69716]
[1024, 1024, 256, 64, 0.69676]
[1024, 1024, 512, 64, 0.69648]
[1024, 1024, 1024, 64, 0.69412]
[1024, 2048, 64, 64, 0.69444]
[1024, 2048, 128, 64, 0.69608]
[1024, 2048, 256, 64, 0.6973]
[1024, 2048, 512, 64, 0.69628]
[1024, 2048, 1024, 64, 0.69274]
[1024, 2048, 2048, 64, 0.6899]
