In [1]:
import numpy as np
import faiss
import time
import pandas as pd
import matplotlib.pyplot as plt
import csv
from os import path, makedirs

In [2]:
# Base directory that the embedding arrays, k-means centroids and index files
# are resolved against (see the loaders below).
root = 'path/to/embeddings/'

# Experiment selection.
method = 'adanns'  # adanns, mg-ivf-rr, mg-ivf-svd
dataset = '1K'  # 1K, 4K, V2

# Embedding dimensionalities.
D = 2048
D_rr_search = 16  # to load high-D database and queryset for AR with rr models

In [4]:
def load_data_helper(config, D_load=2048):
    """Read database/query embeddings and labels for ``config`` from ``root``.

    File names are derived from the module-level ``dataset`` and ``config``.
    For ``dataset == 'V2'`` the database side is read from the ``1K`` train
    files while the queries still come from the ``V2`` val files.

    Parameters
    ----------
    config : str
        Model configuration tag embedded in the file names.
    D_load : int, optional
        Number of leading embedding columns to keep (default 2048).

    Returns
    -------
    tuple
        ``(db_load, qy_load, db_labels, query_labels)``; the two embedding
        arrays are C-contiguous float32 and L2-normalised row-wise.
    """
    # Only the database (train) side is redirected for V2.
    train_prefix = "1K_train_" if dataset == 'V2' else dataset + '_train_'
    val_prefix = dataset + '_val_'

    def _read_embeddings(file_name):
        # Keep the first D_load columns as a contiguous float32 array for FAISS.
        return np.ascontiguousarray(np.load(root + file_name)[:, :D_load], dtype=np.float32)

    db_load = _read_embeddings(train_prefix + config + '-X.npy')
    qy_load = _read_embeddings(val_prefix + config + '-X.npy')
    db_labels = np.load(root + train_prefix + config + '-y.npy')
    query_labels = np.load(root + val_prefix + config + '-y.npy')

    # normalize_L2 works in place on the arrays it is given.
    for embeddings in (db_load, qy_load):
        faiss.normalize_L2(embeddings)

    return db_load, qy_load, db_labels, query_labels


def load_construct_data(D_construct, D_rr_svd, ncentroids):
    """Load the embeddings, labels and k-means artefacts used for cluster construction.

    The embedding config is chosen from the module-level ``method``; the
    centroid and index paths are built from ``root``, ``method`` and
    ``dataset``.

    Parameters
    ----------
    D_construct : int
        Dimensionality the clusters were built with; also the number of
        leading embedding columns that are loaded.
    D_rr_svd : int
        Only used in the config name when ``method == 'mg-ivf-svd'``.
    ncentroids : int
        Number of k-means centroids; used to build the file names.

    Returns
    -------
    tuple
        ``(db_construct, qy_construct, db_labels, query_labels, centroids,
        index_file)``. ``index_file`` is the index path without a file
        extension.

    Raises
    ------
    Exception
        If ``method`` is not one of the supported ANNS methods.
    """
    if method == 'adanns':
        config = f'mrl1_e0_ff{D_construct}'
    elif method == 'mg-ivf-rr':
        config = f'mrl0_e0_ff{D_construct}'
    elif method == 'mg-ivf-svd':
        config = f'mrl0_e0_rr{D_construct}_svd{D_rr_svd}'
    else:
        raise Exception("Unsupported ANNS method.")
    db_construct, qy_construct, db_labels, query_labels = load_data_helper(config, D_construct)

    print("Cluster Contruction DB: ", db_construct.shape)
    print("Cluster Construction queries:", qy_construct.shape)

    # Load kmeans index and centroids with shape (centroid, D_construct)
    size = str(ncentroids)+'ncentroid_'+str(D_construct)+'Dc'
    # V2 is only a test set, so its k-means artefacts are the 1K ones.
    # Use a separate local name: assigning to `dataset` inside this function
    # would make it a local variable and the comparison below would raise
    # UnboundLocalError before the global value is ever read.
    artifact_dataset = '1K' if dataset == 'V2' else dataset
    index_file = root+'index_files/'+method+artifact_dataset+'_kmeans_'+size

    centroids_path = root+'kmeans/'+method+'ncentroids'+str(ncentroids)+"_"+str(D_construct)+'Dc_'+artifact_dataset+'.npy'
    centroids = np.load(centroids_path)
    print("Loaded centroids: ", centroids.shape, centroids_path)

    return db_construct, qy_construct, db_labels, query_labels, centroids, index_file

def load_search_data(D_search):
    """Load the database and query embeddings used for the linear-scan stage.

    The config name depends on the module-level ``method``: ``adanns`` reads
    the ``mrl1`` files, the two ``mg-ivf`` variants read the ``mrl0`` files.

    Parameters
    ----------
    D_search : int
        Search dimensionality; selects the file and the number of columns kept.

    Returns
    -------
    tuple
        ``(db_search, qy_search)`` as returned by ``load_data_helper``;
        the labels are discarded.

    Raises
    ------
    Exception
        If ``method`` is not one of the supported ANNS methods.
    """
    if method not in ('adanns', 'mg-ivf-rr', 'mg-ivf-svd'):
        raise Exception("Unsupported ANNS method.")

    mrl_flag = 1 if method == 'adanns' else 0
    config = f'mrl{mrl_flag}_e0_ff{D_search}'

    db_search, qy_search, _unused_db_labels, _unused_query_labels = load_data_helper(config, D_search)
    return db_search, qy_search


In [5]:
def eval_cluster(val_classes, db_classes, neighbors, k, samples_per_class=1300):
    """Score the retrieved neighbours for the queries of one cluster.

    Parameters
    ----------
    val_classes : array-like
        Class label of each query, indexed by query row.
    db_classes : numpy.ndarray
        Class label of each database vector that ``neighbors`` can refer to.
    neighbors : numpy.ndarray
        2-D integer array with one row of neighbour indices per query.
        Negative entries (FAISS pads with -1 when it returns fewer than ``k``
        results) are treated as missing and never count as a match.
    k : int
        Number of leading neighbours to score per query.
    samples_per_class : int, optional
        Denominator of the recall. Defaults to 1300, the value that used to be
        hard-coded here (presumably the number of database images per class;
        confirm for datasets other than 1K).

    Returns
    -------
    tuple
        ``(mean recall, mean top-k hit rate, mean AP)`` over the queries.
        AP is the sum of the precisions at the matching ranks divided by ``k``.
    """
    recall, topk, APs = [], [], []
    for i in range(neighbors.shape[0]):
        target = val_classes[i]
        indices = neighbors[i][:k]  # k neighbor list for ith val vector

        # A padded -1 would otherwise wrap around and read the label of the
        # last database vector, producing spurious hits.
        found = indices >= 0
        # ravel() keeps the mask 1-D even when labels are stored as (n, 1).
        matches = np.ravel(db_classes[indices] == target) & found

        hits = np.sum(matches)
        topk.append(1 if hits > 0 else 0)
        recall.append(hits / samples_per_class)

        # Precision at each rank, kept only at ranks where a match occurs.
        tps = np.cumsum(matches)
        precs = tps.astype(float) / np.arange(1, matches.size + 1)
        APs.append(np.sum(precs[matches]) / k)

    return np.mean(recall), np.mean(topk), np.mean(APs)

In [6]:
def get_closest_centroids(queries, centroids, D_shortlist):
    """Assign every query to its nearest centroid using the first ``D_shortlist`` dims.

    Both inputs are truncated to ``D_shortlist`` columns, L2-normalised and
    compared with an exact FAISS L2 index.

    Parameters
    ----------
    queries : numpy.ndarray
        2-D query embeddings with at least ``D_shortlist`` columns.
    centroids : numpy.ndarray
        2-D centroid matrix with at least ``D_shortlist`` columns.
    D_shortlist : int
        Number of leading dimensions used for the centroid assignment.

    Returns
    -------
    numpy.ndarray
        Index array returned by the FAISS search with ``k=1``: the id of the
        nearest centroid for each query.

    Raises
    ------
    ValueError
        If ``queries`` or ``centroids`` has fewer than ``D_shortlist`` columns.
    """
    # Slicing past the last column silently returns a narrower array, which
    # FAISS then rejects with a bare assertion; fail early with a clear message.
    if queries.shape[1] < D_shortlist or centroids.shape[1] < D_shortlist:
        raise ValueError(
            "D_shortlist=%d exceeds the available dimensions (queries: %d, centroids: %d)"
            % (D_shortlist, queries.shape[1], centroids.shape[1])
        )

    # np.array copies by default. ascontiguousarray may hand back the caller's
    # own buffer when the slice covers every column, and normalize_L2 would
    # then rescale the caller's data in place.
    xq_shortlist = np.array(queries[:, :D_shortlist], dtype=np.float32, order='C')
    xc_shortlist = np.array(centroids[:, :D_shortlist], dtype=np.float32, order='C')
    faiss.normalize_L2(xq_shortlist)
    faiss.normalize_L2(xc_shortlist)

    centroid_index = faiss.IndexFlatL2(D_shortlist)
    centroid_index.add(xc_shortlist)
    _, I = centroid_index.search(xq_shortlist, 1)

    return I

In [6]:
# Sweep the linear-scan dimensionality D_search over a fixed k-means clustering
# and log top-1 / recall@k / mAP@k for every configuration to kmeans_metrics.csv.
D_search_list = [8, 16, 32, 64, 128, 256, 512, 1024]
ncentroids = 1024

for D in [2048]:
    k=100
    D_rr_svd = D
    D_construct_list = [D]
    D_shortlist_list = [D]

    header = ['d_construct', 'd_search', 'd_shortlist', 'ncentroid', 'top1', 'recall@'+str(k), 'mAP@'+str(k)]
    print(header)
    # Start a fresh CSV; one row per configuration is appended below.
    with open('kmeans_metrics.csv', 'w', encoding='UTF8', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(header)

    # Wall-clock start of the whole sweep. It must not be reassigned inside the
    # loops, otherwise the total printed at the end only covers the last config.
    start = time.time()
    for D_c in D_construct_list:
        # Load all construction data (database, queries, centroids)
        xb_construct, xq_construct, db_labels, query_labels, centroids, index_file = load_construct_data(D_c, D_rr_svd, ncentroids)

        cpu_index = faiss.read_index(index_file+'.index')
        index = faiss.index_cpu_to_all_gpus(cpu_index)
        print("\nLoaded kmeans index:", index_file.split("/")[-1])

        # construct lookup table of centroid --> vectors, i.e. inverted lists
        _, I_db = index.search(xb_construct, 1)
        lut_db = {}
        for c in np.unique(I_db):
            lut_db[c] = np.argwhere(I_db==c)[:,0]

        for D_search in D_search_list:
            print("Linear scan with D_s = ", D_search)
            xb_search, xq_search = load_search_data(D_search)

            for D_shortlist in D_shortlist_list:
                # The search embeddings can only be sliced down to D_shortlist, so
                # they are usable for shortlisting only when D_shortlist <= D_search.
                # Otherwise fall back to the construction-space queries, which are as
                # wide as the centroids.
                # NOTE(review): the fallback is inferred from the recorded runs
                # (d_shortlist=2048 with d_s=8) -- confirm this is the intended source.
                if xq_search.shape[1] >= D_shortlist:
                    shortlist_queries = xq_search
                else:
                    shortlist_queries = xq_construct
                I_q = get_closest_centroids(shortlist_queries, centroids, D_shortlist)
                lut_q = {}

                recall, topk, mAP = [], [], []

                # Iterate over every centroid that at least one query was assigned to
                for c in np.unique(I_q):
                    lut_q[c] = np.argwhere(I_q==c)[:,0]
                    exact_cpu_index = faiss.IndexFlatL2(D_search)

                    # add cluster vectors to index and search only queries that map to that cluster
                    exact = faiss.index_cpu_to_all_gpus(exact_cpu_index)
                    cluster_db = np.ascontiguousarray(xb_search[lut_db[c]][:, :D_search], np.float32)
                    cluster_query = np.ascontiguousarray(xq_search[lut_q[c]][:, :D_search], np.float32)
                    faiss.normalize_L2(cluster_db)
                    faiss.normalize_L2(cluster_query)
                    exact.add(cluster_db)
                    Dist, Ind = exact.search(cluster_query, k)

                    # Gather the labels of this cluster's members so the cluster-local
                    # indices in Ind can index the label arrays directly.
                    cluster_db_labels = db_labels[lut_db[c]]
                    cluster_query_labels = query_labels[lut_q[c]]

                    # Top-1 is pooled over all queries: count hits here, divide once below.
                    nn_1 = Ind[:, 0]
                    pred_1 = cluster_db_labels[nn_1]
                    hits = np.sum(pred_1 == cluster_query_labels)
                    topk.append(hits)

                    # NOTE(review): recall and mAP are averaged per cluster and then
                    # across clusters, so clusters are weighted equally regardless of
                    # how many queries they received.
                    rl, tk, mp = eval_cluster(cluster_query_labels, cluster_db_labels, Ind, k)
                    recall.append(rl)
                    mAP.append(mp)

                top1 = np.sum(topk)/xq_search.shape[0]
                row = [D_c, D_search, D_shortlist, ncentroids, top1, np.mean(recall), np.mean(mAP)]
                print(row)

                with open('kmeans_metrics.csv', 'a', encoding='UTF8', newline='') as f:
                    writer = csv.writer(f)
                    writer.writerow(row)
                print("d_c:%d, d_s: %d, ncentroid: %d" %(D_c, D_search, ncentroids))
                print("Recall@100: ", np.mean(recall))
                print("Top1: ", top1)

    # ncentroids is a single int here (len() on it raises TypeError), so the number
    # of evaluated configs is the product of the lists that are actually iterated.
    n_configs = len(D_search_list) * len(D_construct_list) * len(D_shortlist_list)
    print("Total Time for %d configs = %f" % (n_configs, time.time() - start))

['d_c', 'd_s', 'd_shortlist', 'ncentroid', 'top1', 'recall@100', 'mAP@100', 'overlap']
Cluster Contruction DB:  (1281167, 2048)
Cluster Construction queries: (10000, 2048)

Loaded kmeans index: 1K_kmeans_1024ncentroid_2048d
Loaded centroids:  (1024, 2048) ../../inference_array/resnet50/kmeans/mrl/ncentroids1024_2048d_1K.npy
Linear scan with d =  8
[2048, 8, 2048, 1024, 0.5351, 0.05091163624126155, 0.605462860936279, 0]
d_c:2048, d_s: 8, ncentroid: 1024
Recall@100:  0.05091163624126155
Top1:  0.5351
Linear scan with d =  16
[2048, 32, 2048, 1024, 0.5732, 0.051922105632729296, 0.6227463073287828, 0]
d_c:2048, d_s: 32, ncentroid: 1024
Recall@100:  0.051922105632729296
Top1:  0.5732
Linear scan with d =  64
[2048, 64, 2048, 1024, 0.5785, 0.05198921759119408, 0.6243867377357402, 0]
d_c:2048, d_s: 64, ncentroid: 1024
Recall@100:  0.05198921759119408
Top1:  0.5785
Linear scan with d =  128
[2048, 128, 2048, 1024, 0.5802, 0.05199322657824928, 0.6247377244191619, 0]
d_c:2048, d_s: 128, ncentroi