In [1]:
import numpy as np
import faiss
import torch
import sys
sys.path.append('../')
from utils.py import load_embeddings

In [2]:
# --- Experiment configuration -------------------------------------------
# These names are module-level settings read by the evaluation cell below.
d = 2048                # cluster construction dim
index_type = 'kmeans'
dataset = '1K'          # 1K, 4K, V2
model = 'mrl'           # mrl, ff
root = '../../../inference_array/resnet50/'

# --- Query embeddings ----------------------------------------------------
# load_embeddings yields six values; only the second one (the query set)
# is kept, the rest are discarded immediately.
_, queryset, _, _, _, _ = load_embeddings(model, dataset, d)

# Rescale every query row to unit L2 norm (faiss does this in place), so
# inner products computed later behave as cosine similarities.
faiss.normalize_L2(queryset)
print("Loaded queries:", queryset.shape)

Loaded queries: (50000, 2048)


In [3]:
search_dim = [8, 16, 32, 64, 128, 256, 512, 1024, 2048]
nprobes = [1, 2, 5, 10]
ncentroids = [1024]
topK = [1, 2, 4, 5, 10]


def topk_hit_rates(ranked_clusters, true_cluster, nprobe, ks):
    """Fraction of queries whose reference cluster is among the probed ones.

    Parameters
    ----------
    ranked_clusters : np.ndarray, shape (nqueries, m) with m >= nprobe
        Cluster ids per query, best match first.
    true_cluster : np.ndarray, shape (nqueries,)
        Reference (full-dimension) cluster id per query.
    nprobe : int
        Only the first `nprobe` ranked clusters count as probed.
    ks : sequence of int
        Cut-offs applied *within* the probed clusters, so the effective
        cut-off is min(k, nprobe). This is why columns with k >= nprobe
        report identical values.

    Returns
    -------
    np.ndarray of float64, one hit rate per entry of `ks`.
    """
    probed = ranked_clusters[:, :nprobe]
    # Boolean (nqueries, nprobe) mask: True where the probed id is the reference id.
    hits = probed == true_cluster[:, None]
    return np.array([hits[:, :k].any(axis=1).mean() for k in ks])


for centroid in ncentroids:
    print("Clusters: ", centroid)

    # Load kmeans index
    size = f'{centroid}ncentroid_{d}d'
    index_file = f'{root}index_files/{model}{dataset}_{index_type}_{size}_nbits8_nlist2048'
    cpu_index = faiss.read_index(index_file + '.index')
    # Always bind `index`: without the CPU fallback the name would be
    # undefined on machines without a GPU.
    # NOTE(review): `index` is not used further in this cell.
    if torch.cuda.device_count() > 0:
        index = faiss.index_cpu_to_all_gpus(cpu_index)
    else:
        index = cpu_index

    # Load and normalize centroids
    centroids_path = f'{root}kmeans/{model}ncentroids{centroid}_{d}d_{dataset}.npy'
    centroids = np.load(centroids_path)
    faiss.normalize_L2(centroids)

    # Reference assignment: the best-matching centroid at full dimension.
    gt = np.argsort(-queryset @ centroids.T, axis=1)
    gt_label = gt[:, 0]

    # The low-dimensional ranking does not depend on nprobe, so compute it
    # once per dimension and keep only as many columns as any probe needs.
    max_probe = max(nprobes, default=0)
    low_d_clusters = {}
    for dim in search_dim:
        # Explicit copies: when dim spans the full width, a plain slice
        # shares memory with queryset/centroids and normalize_L2 (in place)
        # would silently modify them on every pass.
        q = np.array(queryset[:, :dim], order='C')
        faiss.normalize_L2(q)
        c = np.array(centroids[:, :dim], order='C')
        faiss.normalize_L2(c)
        low_d_clusters[dim] = np.argsort(-q @ c.T, axis=1)[:, :max_probe].copy()

    for nprobe in nprobes:
        print("\nNumber of probes:", nprobe)
        print([f'top{k}' for k in topK])
        for dim in search_dim:
            print(topk_hit_rates(low_d_clusters[dim], gt_label, nprobe, topK))

Clusters:  1024

Number of probes: 1
['top1', 'top2', 'top4', 'top5', 'top10']
[0.7355 0.7355 0.7355 0.7355 0.7355]
[0.84186 0.84186 0.84186 0.84186 0.84186]
[0.89878 0.89878 0.89878 0.89878 0.89878]
[0.93288 0.93288 0.93288 0.93288 0.93288]
[0.95518 0.95518 0.95518 0.95518 0.95518]
[0.96946 0.96946 0.96946 0.96946 0.96946]
[0.9822 0.9822 0.9822 0.9822 0.9822]
[0.99078 0.99078 0.99078 0.99078 0.99078]
[1. 1. 1. 1. 1.]

Number of probes: 2
['top1', 'top2', 'top4', 'top5', 'top10']
[0.7355  0.83386 0.83386 0.83386 0.83386]
[0.84186 0.92572 0.92572 0.92572 0.92572]
[0.89878 0.96662 0.96662 0.96662 0.96662]
[0.93288 0.98402 0.98402 0.98402 0.98402]
[0.95518 0.99246 0.99246 0.99246 0.99246]
[0.96946 0.99674 0.99674 0.99674 0.99674]
[0.9822  0.99894 0.99894 0.99894 0.99894]
[0.99078 0.99982 0.99982 0.99982 0.99982]
[1. 1. 1. 1. 1.]

Number of probes: 5
['top1', 'top2', 'top4', 'top5', 'top10']
[0.7355  0.83386 0.89734 0.91348 0.91348]
[0.84186 0.92572 0.96844 0.97612 0.97612]
[0.89878 0.9666