In [33]:
import hnswlib
import numpy as np 
import struct
import heapq
import time
import pickle
import os

from sklearn.cluster import KMeans

from distributed_graph_index_construction import HNSW_index

In [3]:
def load_obj(dirc, name):
    with open(os.path.join(dirc, name + '.pkl'), 'rb') as f:
        return pickle.load(f)

In [74]:
def recall_eval(result_list, gt, k):
    """
    Input:
        result list: a 2-dim list
            dim 1: query num
            dim 2: topK
        gt: a ground truth 2-d numpy array
            dim 1: query num
            dim 2: topK, 1000 for sift dataset
        k: topK to be used for recall evaluation,
            *** can be anything smaller than the dim2 of result_list ***)
    Output:
        recall
    """

    count = 0
    for i in range(query_num):
        gt_set = set()
        for j in range(k):
            gt_set.add(gt[i][j])
        for j in range(k):
            vec_ID = result_list[i][j]
            if vec_ID in gt_set:
                count += 1
    recall = count / (query_num * k)
    return recall

In [5]:
def mmap_fvecs(fname):
    x = np.memmap(fname, dtype='int32', mode='r')
    d = x[0]
    return x.view('float32').reshape(-1, d + 1)[:, 1:]

def mmap_bvecs(fname):
    x = np.memmap(fname, dtype='uint8', mode='r')
    d = x[:4].view('int32')[0]
    return x.reshape(-1, d + 4)[:, 4:]

def ivecs_read(fname):
    a = np.fromfile(fname, dtype='int32')
    d = a[0]
    # Wenqi: Format of ground truth (for 10000 query vectors):
    #   1000(topK), [1000 ids]
    #   1000(topK), [1000 ids]
    #        ...     ...
    #   1000(topK), [1000 ids]
    # 10000 rows in total, 10000 * 1001 elements, 10000 * 1001 * 4 bytes
    return a.reshape(-1, d + 1)[:, 1:].copy()

def fvecs_read(fname):
    return ivecs_read(fname).view('float32')

In [6]:
dbname = 'SIFT1M'
index_path='../indexes/{}_index.bin'.format(dbname)
dim=128

if dbname.startswith('SIFT'):
    # SIFT1M to SIFT1000M
    dbsize = int(dbname[4:-1])
    xb = mmap_bvecs('/mnt/scratch/wenqi/Faiss_experiments/bigann/bigann_base.bvecs')
    xq = mmap_bvecs('/mnt/scratch/wenqi/Faiss_experiments/bigann/bigann_query.bvecs')
    gt = ivecs_read('/mnt/scratch/wenqi/Faiss_experiments/bigann/gnd/idx_%dM.ivecs' % dbsize)

    N_VEC = int(dbsize * 1000 * 1000)

    # trim xb to correct size
    xb = xb[:dbsize * 1000 * 1000]

    # Wenqi: load xq to main memory and reshape
    xq = xq.astype('float32').copy()
    xq = np.array(xq, dtype=np.float32)
    gt = np.array(gt, dtype=np.int32)

    print("Vector shapes:")
    print("Base vector xb: ", xb.shape)
    print("Query vector xq: ", xq.shape)
    print("Ground truth gt: ", gt.shape)
else:
    print('unknown dataset', dbname, file=sys.stderr)
    sys.exit(1)

Vector shapes:
Base vector xb:  (1000000, 128)
Query vector xq:  (10000, 128)
Ground truth gt:  (10000, 1000)


In [14]:
def distributed_search(query_vec, kmeans, index_list, k, ef, all_vectors):
    """
    query_vec: a numpy array of a single d-dimensional vector
    kmeans: the kmeans object
    index_list: a list of loaded python object of hnsw index
    """
    query_kmeans_format = query_vec.reshape(1,-1).astype(np.float64)
    partition_id = kmeans.predict(query_kmeans_format)[0]
    search_path = []
    all_results = set() # deduplicate results
    
    while True:
        current_index = index_list[partition_id]
        search_path.append(partition_id)
        
        results, local_results, remote_results, search_remote, remote_partition_id = \
            current_index.searchKnnPlusRemoteCache(query_vec, k, ef, all_vectors, debug=False)
        for r in results:
            all_results.add(r)
        if not search_remote:
            break
        else:
            if remote_partition_id in search_path:
                break
            else:
                partition_id = remote_partition_id
    
    # merge all results
    results_heap = []
    for dist, server_ID, vec_ID in all_results:
        heapq.heappush(results_heap, (-dist, server_ID, vec_ID))
    while len(results_heap) > k:
        heapq.heappop(results_heap)

    results = []
    while len(results_heap) > 0:
        dist, server_ID, vec_ID = results_heap[0]
        results.append((-dist, server_ID, vec_ID))
        heapq.heappop(results_heap)
    results.reverse()
    
    # this is in descending order for distance
    return results, search_path

## Load Index

In [8]:
N_SUBGRAPH = 32
parent_dir = '../indexes_subgraph_kmeans/SIFT1M_{}_subgraphs'.format(N_SUBGRAPH)

In [9]:
all_hnsw_indexes = [load_obj(parent_dir, 'subgraph_{}_with_remote_edges'.format(i)) for i in range(N_SUBGRAPH)]

In [10]:
for i in range(N_SUBGRAPH): 
    print(len(all_hnsw_indexes[i].remote_links))

27778
29567
33003
29067
27136
32485
30647
29204
28404
29919
32778
32829
32178
29858
34848
40082
30120
26974
31261
27816
31228
27589
28360
26756
43264
34713
33191
34977
27934
31373
32349
32312


In [11]:
kmeans = load_obj(parent_dir, 'kmeans')

## Recall on distributed graph search

In [26]:
result_list = []
search_path_list = []
query_num = 10000
k = 100

for i in range(query_num):
    results, search_path = distributed_search(
        xq[i], kmeans, index_list=all_hnsw_indexes, k=k, ef=128, all_vectors=xb)
    result_list.append(results)
    search_path_list.append(search_path)

In [27]:
## Get recall for consider up to 1 remote hop
## Wenqi comment: for k-means-based method, the recall is really high
# First 100 queries -> 1.0 recall
# First 1000 queries -> 0.994 recall
# First 10000 queries -> 0.9916 recall 

result_list_I = []
for i in range(len(result_list[i])):

print("R@1 =", recall_eval(result_list=result_list, gt=gt, k=1))
print("R@10 =", recall_eval(result_list=result_list, gt=gt, k=10))
print("R@100 =", recall_eval(result_list=result_list, gt=gt, k=100))

R@1 = 0.9916
R@10 = 0.98342
R@100 = 0.954067


In [None]:
# Count how many searches travel to remote node
# First 100 queries -> 31% travels to remote node; average search path length = 1.31
# First 100 queries -> 30% travels to remote node; average search path length = 1.304
# First 100 queries -> 29.91% travels to remote node; average search path length = 1.3054 （1 case travel to a third server）

search_remote_count = 0
total_path_length = 0
path_len = np.array([len(search_path_list[i]) for i in range(len(search_path_list))])
len_count = dict() 

for i in range(query_num):
    total_path_length += path_len[i]
    if path_len[i] in len_count:
        len_count[path_len[i]] += 1
    else:
        len_count[path_len[i]] = 1
    if len(search_path_list[i]) > 1: search_remote_count += 1
        
average_path_length = total_path_length / query_num
print("search remote rate: {} ({} cases)".format(search_remote_count/query_num, search_remote_count))
print("average path length: {}".format(average_path_length))
print("search path length distribution: {}".format(len_count))

## Partition-based search

Without distributed search. Using K-means to decide m partitions to search. Explore the relationship between m and recall.

In [57]:
""" Load hnswlib index """
all_server_IDs = np.arange(N_SUBGRAPH)
all_hnswlib_indexes = [hnswlib.Index(space='l2', dim=dim) for i in all_server_IDs]
parent_dir = '../indexes_subgraph_kmeans/SIFT1M_32_subgraphs'
all_index_paths=[os.path.join(parent_dir, 'subgraph_{}.bin'.format(i)) for i in all_server_IDs]
for i in all_server_IDs:
    print("\nLoading hnswlib index from {}\n".format(all_index_paths[i]))
    all_hnswlib_indexes[i].load_index(all_index_paths[i])


Loading hnswlib index from ../indexes_subgraph_kmeans/SIFT1M_32_subgraphs/subgraph_0.bin


Loading hnswlib index from ../indexes_subgraph_kmeans/SIFT1M_32_subgraphs/subgraph_1.bin


Loading hnswlib index from ../indexes_subgraph_kmeans/SIFT1M_32_subgraphs/subgraph_2.bin


Loading hnswlib index from ../indexes_subgraph_kmeans/SIFT1M_32_subgraphs/subgraph_3.bin


Loading hnswlib index from ../indexes_subgraph_kmeans/SIFT1M_32_subgraphs/subgraph_4.bin


Loading hnswlib index from ../indexes_subgraph_kmeans/SIFT1M_32_subgraphs/subgraph_5.bin


Loading hnswlib index from ../indexes_subgraph_kmeans/SIFT1M_32_subgraphs/subgraph_6.bin


Loading hnswlib index from ../indexes_subgraph_kmeans/SIFT1M_32_subgraphs/subgraph_7.bin


Loading hnswlib index from ../indexes_subgraph_kmeans/SIFT1M_32_subgraphs/subgraph_8.bin


Loading hnswlib index from ../indexes_subgraph_kmeans/SIFT1M_32_subgraphs/subgraph_9.bin


Loading hnswlib index from ../indexes_subgraph_kmeans/SIFT1M_32_subgraphs/subgraph_10.bin

In [34]:
cluster_centers = kmeans.cluster_centers_
print(cluster_centers.shape, cluster_centers)

(32, 128) [[16.65008455 53.98379369 83.4343292  ... 16.89162909 18.9815389
  19.68714769]
 [10.67950214 10.82730455 12.91391158 ...  5.00518605  5.97445871
   9.6548684 ]
 [12.65382377  9.79548365 12.54801536 ... 12.06006286 13.3032243
  15.60807822]
 ...
 [57.04531152 19.92309629  7.68596602 ...  6.68609188  6.42706104
   8.05739459]
 [35.69270833 14.90234375  9.1859375  ... 22.42903646  6.80299479
   8.53033854]
 [27.29260682 28.88550168 23.75180029 ...  7.3731397  12.95079213
  14.23847816]]


In [45]:
def compute_centroid_distances(cluster_centers, query_vecs):
    """
    Input:
        cluster_centers: 2-d array (num_clusters, dim)
        query_vecs: 2-d array (num_queries, dim)
    Output:
        distance_mat (num_queries, num_clusters),
            each element is a distance (L2 square)
    """
    num_clusters, dim = cluster_centers.shape
    nq = query_vecs.shape[0]
    assert dim == query_vecs.shape[1]
    
    distance_mat = np.zeros((nq, num_clusters))
    
    for i in range(num_clusters):
        centroid_replications = np.tile(cluster_centers[i], (nq,1))
        distance_mat[:, i] = np.sum((query_vecs - centroid_replications) ** 2, axis=1)
    
    return distance_mat

def kmeans_predict_sorted(cluster_centers, query_vecs):
    """
    Compute the cell centroid IDs for each query in a sorted manner 
        (increasing distance)
    
    Input:
        cluster_centers: 2-d array (num_clusters, dim)
        query_vecs: 2-d array (num_queries, dim)
    Output:
        ID_mat (num_queries, num_clusters),
            each element is a centroid ID 
    """
    num_clusters, dim = cluster_centers.shape
    nq = query_vecs.shape[0]
    
    distance_mat = compute_centroid_distances(cluster_centers, query_vecs)
    ID_mat = np.argsort(distance_mat, axis=1)
    
    return ID_mat

In [51]:
# Evaluate function correctness
print(kmeans_predict_sorted(cluster_centers, xq[:10]))
# query_kmeans_format = query_vec.reshape(1,-1).astype(np.float64)
partition_id = kmeans.predict(xq[:10].astype(np.float64))
print(partition_id)

[[14 20 24  9  7 12  0 31 25  5 27 23 19  8 15 13 18 21  2 26 11 30  6 17
  28  4 29 22  1 16  3 10]
 [15 22 26  6 28 17 11 18 13  2 21  4  0 20 12 16 30  1 23  7 31  5  3 10
   8 14 19 24 27  9 29 25]
 [29 21 19 25 23  7 27  5  8  9 24 14 31 20  4  0 12 13 10 30  3 18  2 22
  28 17 26  1  6 11 15 16]
 [10 29  4 19  1  3 17  8  6 21 23  7  5 16 22 26 31 18 11 30 13 14 27 20
  25  2  9 28 12 15  0 24]
 [15 13 17 22 26 28  2 11  6 18 21  4 30 16  1 20 12  5  0  7  3 23 27 10
  19 31  8 24  9 14 29 25]
 [29 10  4 19 23  3  8 21 25  7  1 22  6  5 17 13 31 27 14 30  9 18 12 26
  11 16  0 20 28 15  2 24]
 [29 19 10  5 23 27  4  8 25  7  3 21  1 31  9 14 22 13 12 20 24  0 17  6
  26 18 30 11  2 16 28 15]
 [14 31 20  0 12 24 25 19 23  7 27  9 18  5 29  8  6 21 15 22  4  3 28 13
  26  1  2 30 11 17 10 16]
 [26 16 17  8 15 30 20 21 11 28 22  4 10  9  2  7 18 13  1  3  6 19  5  0
  29 31 12 23 27 14 24 25]
 [ 5  8 19 29 20  4  9 27  1 23 10  7 13  2 25 24 31 26 12  3 22 21 17 14
   6 11 18  0 28 

In [52]:
sorted_partition_IDs = kmeans_predict_sorted(cluster_centers, xq)

In [53]:
sorted_partition_IDs.shape

(10000, 32)

In [65]:
""" Search several partitions per query vector """
MAX_VISITED_PARTITIONS = 8 # explore at most 
nq = xq.shape[0]
dim = xq.shape[1]
k = 100

all_I = np.zeros((nq, k * MAX_VISITED_PARTITIONS), dtype=int)
all_D = np.zeros((nq, k * MAX_VISITED_PARTITIONS))

for index in all_hnswlib_indexes: 
    index.set_ef(128)

for vec_id in range(nq):
    if vec_id % 1000 == 0: print("query id: ", vec_id)
    for j in range(MAX_VISITED_PARTITIONS):
        index_id = sorted_partition_IDs[vec_id][j]
        all_I[vec_id][j * k: (j + 1) * k], all_D[vec_id, j * k: (j + 1) * k] = \
            all_hnswlib_indexes[index_id].knn_query(xq[vec_id], k=k)

query id:  0
query id:  1000
query id:  2000
query id:  3000
query id:  4000
query id:  5000
query id:  6000
query id:  7000
query id:  8000
query id:  9000


In [82]:
print(all_I.shape)

(10000, 800)


In [98]:
# Compute the recall combining {k} x {num_partitions (1~MAX_VISITED_PARTITIONS)}
for tmp_k in [1, 10, 100]:
    
    all_I_k_tmp = np.zeros((nq, tmp_k * MAX_VISITED_PARTITIONS), dtype=int)
    all_D_k_tmp = np.zeros((nq, tmp_k * MAX_VISITED_PARTITIONS))

    # copy tmp_k results from hnswlib
    for vec_id in range(nq):
        for j in range(MAX_VISITED_PARTITIONS):
            all_I_k_tmp[vec_id][j * tmp_k: (j + 1) * tmp_k] = all_I[vec_id][j * k: j * k + tmp_k]
            all_D_k_tmp[vec_id][j * tmp_k: (j + 1) * tmp_k] = all_D[vec_id][j * k: j * k + tmp_k]

    # for upto MAX_VISITED_PARTITIONS partition, compute recall
    for tmp_partition in range(1, 1 + MAX_VISITED_PARTITIONS):
        
        D_I_k_tmp = []
        for vec_id in range(nq):
            D_I_k_tmp.append([])
            for j in range(tmp_partition):
                for m in range(tmp_k):
                    D_I_k_tmp[vec_id].append((all_D_k_tmp[vec_id][j * tmp_k + m], all_I_k_tmp[vec_id][j * tmp_k + m]))

        D_k_tmp = []
        I_k_tmp = []
        for vec_id in range(nq):
            D_I_tmp = sorted(D_I_k_tmp[vec_id])[:tmp_k]
            D_k_tmp.append([D for D, I in D_I_tmp[:tmp_k]])
            I_k_tmp.append([I for D, I in D_I_tmp[:tmp_k]])
        
        print("Num partition = {}\tR@{} = {}".format(
            tmp_partition, tmp_k, recall_eval(result_list=I_k_tmp, gt=gt, k=tmp_k)))

Num partition = 1	R@1 = 0.6942
Num partition = 2	R@1 = 0.8716
Num partition = 3	R@1 = 0.9356
Num partition = 4	R@1 = 0.9659
Num partition = 5	R@1 = 0.9819
Num partition = 6	R@1 = 0.9891
Num partition = 7	R@1 = 0.9928
Num partition = 8	R@1 = 0.9959
Num partition = 1	R@10 = 0.66609
Num partition = 2	R@10 = 0.84712
Num partition = 3	R@10 = 0.91883
Num partition = 4	R@10 = 0.95313
Num partition = 5	R@10 = 0.97177
Num partition = 6	R@10 = 0.98224
Num partition = 7	R@10 = 0.98794
Num partition = 8	R@10 = 0.99173
Num partition = 1	R@100 = 0.612749
Num partition = 2	R@100 = 0.799239
Num partition = 3	R@100 = 0.880868
Num partition = 4	R@100 = 0.922392
Num partition = 5	R@100 = 0.946111
Num partition = 6	R@100 = 0.960622
Num partition = 7	R@100 = 0.969353
Num partition = 8	R@100 = 0.975071
