In [1]:
import pandas as pd
import numpy as np
import time
import os
import faiss
import csv

## Configuration Variables

In [2]:
D = 2048 # vector dim
ROOT_DIR = '../../inference_array/resnet50/'
CONFIG = 'mrl/' # ['mrl/', 'rr/']
# True for MRL (one nested representation, truncated per dim); 'rr/' loads per-dimension fixed-feature files.
NESTING = CONFIG == 'mrl/'
# NOTE: load_knn_array has no branch for 'exactl2' and raises "Unsupported Search Index" for it.
SEARCH_INDEX = 'ivfpq' # ['exactl2', 'ivfpq', 'opq', 'hnsw32']
DATASET = '1K' # 1K, V2, 4K

# Quantization Variables
nbits = 8 # nbits used to represent centroid id; total possible is k* = 2**nbits
nlist = 1024  # how many Voronoi cells (must be >= k*)
iterator = [8, 16, 32, 64, 128, 256, 512, 1024, 2048] # vector dim D 

# Search-time knob used to select the neighbor file: number of IVF cells probed,
# or efSearch for HNSW indices. Defined here so later cells do not rely on a
# value leaked from an earlier loop.
nprobe = 1

# Number of PQ sub-quantizers, i.e. compression in bytes. Only meaningful for
# PQ-based indices; None otherwise so cells that reference M do not NameError.
M = 32 if SEARCH_INDEX in ['ivfpq', 'opq'] else None

In [3]:
def compute_mAP_recall_at_k(val_classes, db_classes, neighbors, k, num_relevant=1300):
    """
    Computes retrieval metrics over the top-k neighbors of every query vector.

    A retrieved database item is a hit when its class equals the query's class.
    Let m be the size of the val (query) set and n the size of the database.

      val_classes:   (m x 1) or (m,) class index for each vector in the val set
      db_classes:    (n x 1) or (n,) class index for each vector in the train set
      neighbors:     (m x >=k) indices in train set of the neighbors of each val
                     vector, best first. Negative indices (faiss pads queries that
                     found fewer results with -1) are treated as "no result".
      k:             shortlist length that is scored
      num_relevant:  denominator used for recall. The default 1300 matches the
                     original hard-coded value (roughly the per-class training-set
                     size of ImageNet-1K) — TODO confirm for other datasets.

    Returns the mean over queries of:
      mAP         AP@k = (sum of precision@r over hit ranks r) / k
      precision   hits / k
      recall      hits / num_relevant
      topk        1 if any of the top-k neighbors is a hit, else 0
      unique_cls  number of distinct classes among the top-k neighbors

    ImageNet-1K shapes: val_classes (50000, 1), db_classes (1281167, 1),
    neighbors (50000, k).
    """
    # Flatten (x, 1) label columns so each query's target is a scalar.
    query_cls = np.asarray(val_classes).reshape(-1)
    db_cls = np.asarray(db_classes).reshape(-1)
    neighbors = np.asarray(neighbors)

    APs, precision, recall, topk, unique_cls = [], [], [], [], []

    for i in range(query_cls.shape[0]):  # score each query's k-NN list
        target = query_cls[i]
        indices = neighbors[i, :k]
        # db_cls[-1] would silently return the last database label, so drop
        # padding entries instead of letting them count as retrieved items.
        indices = indices[indices >= 0]
        labels = db_cls[indices]
        matches = labels == target
        hits = int(np.count_nonzero(matches))

        unique_cls.append(len(np.unique(labels)))
        topk.append(1 if hits > 0 else 0)
        precision.append(hits / k)
        recall.append(hits / num_relevant)

        # precision@r at every 1-based rank r, summed over the ranks that are hits.
        # Uses the actual list length so short rows do not cause a shape mismatch.
        precs = np.cumsum(matches) / np.arange(1, len(labels) + 1)
        APs.append(np.sum(precs[matches]) / k)

    return np.mean(APs), np.mean(precision), np.mean(recall), np.mean(topk), np.mean(unique_cls)

In [4]:
def get_k_recall_at_N(exact_gt, neighbors, k=40, N=2048):
    """
    k-Recall@N: the fraction of the k exact nearest neighbors of each query that
    appear among the first N approximate neighbors, averaged over all queries.
    Let q be the number of queries (rows of exact_gt).

      exact_gt:   (q x k) exact-search nearest neighbors of the query set
      neighbors:  (q x N) approximate (ANNS) nearest neighbors of the query set
      k:          number of exact neighbors treated as ground truth
      N:          number of approximate neighbors considered per query
    """
    true_nn = exact_gt[:, :k]
    num_queries = exact_gt.shape[0]

    # Per query: size of the overlap between the ground-truth set and the first N ANN results.
    overlap = sum(
        len(set(true_nn[q]) & set(neighbors[q, :N]))
        for q in range(num_queries)
    )
    return overlap / (num_queries * k)

## Load database, query, and neighbor arrays and compute metrics

In [5]:
def load_knn_array(dim, **kwargs):
    """
    Load the saved ANN neighbor array for `dim`-dimensional representations.

    The file path is built from the notebook configuration (ROOT_DIR, CONFIG,
    SEARCH_INDEX, DATASET) plus whichever index hyper-parameters SEARCH_INDEX
    uses. Hyper-parameters are taken from keyword arguments when passed and not
    None; otherwise from the notebook global of the same name.

      dim:     vector dimensionality the neighbors were computed with
      kwargs:  optional overrides for M, nlist, nprobe, qtype

    Returns the neighbor indices as a numpy array, or None when the
    configuration is skipped (PQ with M > dim) or the file does not exist.
    Raises Exception for an unsupported SEARCH_INDEX, and NameError when a
    required hyper-parameter is neither passed nor defined globally.
    """
    def param(name):
        # Explicit (non-None) keyword arguments win over notebook globals.
        if kwargs.get(name) is not None:
            return kwargs[name]
        if name in globals():
            return globals()[name]
        raise NameError(f"'{name}' was neither passed to load_knn_array nor defined globally")

    if SEARCH_INDEX in ['ivfpq', 'opq']:
        m = param('M')
        if (m > dim):
            # More sub-quantizers than dimensions is not a usable PQ setting; skip.
            return
        size = 'm'+str(m)+'_nlist'+str(param('nlist'))+"_nprobe"+str(param('nprobe'))+"_"
    elif SEARCH_INDEX == 'ivfsq':
        size = str(param('qtype'))+'qtype_'
    elif SEARCH_INDEX == 'kmeans':
        size = str(param('nlist'))+'ncentroid_'
    elif SEARCH_INDEX == 'ivf':
        size = 'nlist'+str(param('nlist'))+"_nprobe"+str(param('nprobe'))+"_"
    elif SEARCH_INDEX in ['hnsw32', 'hnswpq_M32_pq-m8','hnswpq_M32_pq-m16','hnswpq_M32_pq-m32','hnswpq_M32_pq-m64', 'hnswpq_M32_pq-m128']:
        # For HNSW indices the nprobe value is the efSearch setting.
        size = 'efsearch'+str(param('nprobe'))+"_"
    else:
        raise Exception(f"Unsupported Search Index: {SEARCH_INDEX}")

    # Load neighbors array
    neighbors_path = ROOT_DIR + "neighbors/" + CONFIG + SEARCH_INDEX+"/"+SEARCH_INDEX + "_" + size \
                + "2048shortlist_" + DATASET + "_d"+str(dim)+".csv"
    
    if not os.path.exists(neighbors_path):
        print(neighbors_path.split("/")[-1] + " not found")
        return

    return pd.read_csv(neighbors_path, header=None).to_numpy()


def print_metrics(iterator, shortlist, metric, nprobe=1, N=2048, M=None):
    """
    Computes and prints retrieval metrics for each representation dimension.

      iterator:   (List) representation dimensions to evaluate (one neighbor file per dim)
      shortlist:  (List) shortlist lengths k to score
      metric:     Name of metric ['topk', 'mAP', 'precision', 'recall', 'unique_cls', 'k_recall_at_n']
      nprobe:     Number of clusters probed during search (IVF) OR 'efSearch' for HNSW search quality
      N:          Number of data points retrieved for k-recall@N
      M:          PQ sub-quantizer count; defaults to the notebook-global M if one is defined

    Raises Exception("Unsupported metric!") before loading any file when `metric`
    is unknown. Dimensions whose neighbor array cannot be loaded are skipped.
    """
    label_metrics = ['topk', 'mAP', 'precision', 'recall', 'unique_cls']
    # Validate up front: an unknown metric used to fail only after all the
    # label and neighbor loading and the full metric computation.
    if metric != 'k_recall_at_n' and metric not in label_metrics:
        raise Exception("Unsupported metric!")

    if M is None:
        # Non-PQ configs may never define M; avoid a NameError and print None.
        M = globals().get('M')

    # Class labels are only needed for the label-based metrics, not k-recall@N.
    needs_labels = metric in label_metrics

    # Load database and query set for nested models
    if NESTING and needs_labels:
        # Database: 1.2M x 1 for Imagenet-1K
        if DATASET == 'V2':
            db_labels = np.load(ROOT_DIR + "1K_train_mrl1_e0_ff2048-y.npy")
        else:
            db_labels = np.load(ROOT_DIR + DATASET + "_train_mrl1_e0_ff2048-y.npy")

        # Query set: 50K x 1 for Imagenet-1K
        query_labels = np.load(ROOT_DIR + DATASET + "_val_mrl1_e0_ff2048-y.npy")

    for dim in iterator:
        # Load database and query set for fixed feature models
        if not NESTING and needs_labels:
            db_labels = np.load(ROOT_DIR + DATASET + "_train_mrl0_e0_ff"+str(dim)+"-y.npy")
            # NOTE(review): query labels are always read from the D-dim file while
            # db labels use `dim`; presumably labels are identical across dims — confirm.
            query_labels = np.load(ROOT_DIR + DATASET + "_val_mrl0_e0_ff"+str(D)+"-y.npy")

        neighbors = load_knn_array(dim, M=M, nlist=nlist, nprobe=nprobe)
        if neighbors is None:
            # File missing (already reported) or configuration skipped (PQ with M > dim).
            continue

        for k in shortlist:
            if metric == 'k_recall_at_n':
                # Ground truth: k-NN from exact L2 search (2048-d MRL when nested,
                # the per-dim model otherwise).
                # NOTE(review): always the 1K ground-truth files, regardless of DATASET.
                if NESTING:
                    gt_path = ROOT_DIR + f'k-recall@N_ground_truth/mrl_exactl2_2048dim_{k}shortlist_1K.csv'
                else:
                    gt_path = ROOT_DIR + f'k-recall@N_ground_truth/rr_exactl2_{dim}dim_{k}shortlist_1K.csv'
                gt_neighbors = pd.read_csv(gt_path, header=None).to_numpy()

                k_recall = get_k_recall_at_N(gt_neighbors, neighbors, k, N)
                print(f'{k}-Recall@{N} = {k_recall}')
            else:
                mAP, precision, recall, topk, unique_cls = compute_mAP_recall_at_k(query_labels, db_labels, neighbors, k)
                values = {'topk': topk, 'mAP': mAP, 'precision': precision,
                          'recall': recall, 'unique_cls': unique_cls}
                print(f'{metric}, {dim}, {M}, {nprobe}, {values[metric]}')

## Example: Traditional Retrieval Metrics (Top-1, mAP, Recall)

In [6]:
# Example evaluation for IVFPQ: top-1 (k=1), mAP@10 and recall@100 for every
# (M, nprobe) combination. The loops rebind the notebook-level M and nprobe.
iterator = [16, 32]
print("Index:", SEARCH_INDEX)
print("metric, D, M, nprobe, value")
for M in [8]:
    for nprobe in [1]:
        for metric_name, k_shortlist in (('topk', 1), ('mAP', 10), ('recall', 100)):
            print_metrics(iterator, [k_shortlist], metric_name, nprobe)

Index: ivfpq
metric, D, M, nprobe, value
topk, 16, 8, 1, 0.6775
topk, 32, 8, 1, 0.6861
mAP, 16, 8, 1, 0.6306868079365078
mAP, 32, 8, 1, 0.6374524079365079
recall, 16, 8, 1, 0.05151807692307692
recall, 32, 8, 1, 0.051838800000000004


## Example: ANNS Metric: k-Recall@N

In [64]:
USE_K_RECALL_AT_N = True  # NOTE(review): not read by any visible cell
SEARCH_INDEX = 'hnsw32'
iterator = [8, 16, 32, 64]

# The metric name must match what print_metrics dispatches on ('k_recall_at_n');
# any other name raises "Unsupported metric!". For HNSW, nprobe is efSearch.
print_metrics(iterator, [40], 'k_recall_at_n', nprobe=1, N=2048)

k-recall@N GT:  (50000, 40)
40-Recall@2048 = 0.2071915
k-recall@N GT:  (50000, 40)
40-Recall@2048 = 0.311641
k-recall@N GT:  (50000, 40)
40-Recall@2048 = 0.377283
k-recall@N GT:  (50000, 40)
40-Recall@2048 = 0.4137225
