In [2]:
import heapq
import os
import pickle
import struct
import sys
import time

import hnswlib
import numpy as np
from sklearn.cluster import KMeans

from distributed_graph_index_construction import HNSW_index

In [4]:
def load_obj(dirc, name):
    """Unpickle and return the object stored at ``<dirc>/<name>.pkl``.

    NOTE: ``pickle.load`` can execute arbitrary code; only load files you produced.
    """
    pkl_path = os.path.join(dirc, name + '.pkl')
    with open(pkl_path, 'rb') as handle:
        obj = pickle.load(handle)
    return obj

In [19]:
def mmap_fvecs(fname):
    """Memory-map an .fvecs file and return an (n, d) float32 view of it.

    Each record is one int32 dimension ``d`` followed by ``d`` float32 values.
    The leading dimension column is dropped from the returned view.
    """
    raw = np.memmap(fname, dtype='int32', mode='r')
    dim = raw[0]
    records = raw.view('float32').reshape(-1, dim + 1)
    return records[:, 1:]

def mmap_bvecs(fname):
    """Memory-map a .bvecs file and return an (n, d) uint8 view of it.

    Each record is a 4-byte int32 dimension ``d`` followed by ``d`` uint8 values.
    The 4 header bytes of every record are dropped from the returned view.
    """
    raw = np.memmap(fname, dtype='uint8', mode='r')
    # the dimension is stored in the first 4 bytes, read as a native int32
    dim = raw[:4].view('int32')[0]
    record_len = dim + 4
    return raw.reshape(-1, record_len)[:, 4:]

def ivecs_read(fname):
    """Read an .ivecs file fully into memory as an (n, d) int32 array.

    Wenqi: format of the ground-truth file (for 10000 query vectors):
        1000(topK), [1000 ids]
        1000(topK), [1000 ids]
             ...     ...
        1000(topK), [1000 ids]
    i.e. 10000 rows, 10000 * 1001 int32 elements, 10000 * 1001 * 4 bytes.
    The per-row length column is dropped and the result is a fresh copy.
    """
    flat = np.fromfile(fname, dtype='int32')
    dim = flat[0]
    table = flat.reshape(-1, dim + 1)
    ids_only = table[:, 1:]
    return ids_only.copy()

def fvecs_read(fname):
    """Read an .fvecs file into memory as an (n, d) float32 array.

    Reuses ivecs_read (same record layout) and reinterprets the int32 payload
    bits as float32 without copying.
    """
    as_int32 = ivecs_read(fname)
    return as_int32.view('float32')

In [20]:
dbname = 'SIFT1M'
index_path='../indexes/{}_index.bin'.format(dbname)
dim=128
# Directory containing the BIGANN files (base/query .bvecs and the gnd/ ground truth).
BIGANN_DIR = '/mnt/scratch/wenqi/Faiss_experiments/bigann'

if dbname.startswith('SIFT'):
    # SIFT1M to SIFT1000M: the digits between 'SIFT' and the trailing 'M' give the
    # database size in millions of vectors.
    dbsize = int(dbname[4:-1])
    xb = mmap_bvecs(os.path.join(BIGANN_DIR, 'bigann_base.bvecs'))
    xq = mmap_bvecs(os.path.join(BIGANN_DIR, 'bigann_query.bvecs'))
    gt = ivecs_read(os.path.join(BIGANN_DIR, 'gnd', 'idx_%dM.ivecs' % dbsize))

    N_VEC = int(dbsize * 1000 * 1000)

    # trim xb to correct size (still a memory-mapped view, nothing is read yet)
    xb = xb[:N_VEC]

    # Wenqi: load xq to main memory as float32 -- one copy is enough
    xq = np.array(xq, dtype=np.float32)
    # ivecs_read already returns an in-memory int32 array; asarray avoids re-copying it
    gt = np.asarray(gt, dtype=np.int32)

    print("Vector shapes:")
    print("Base vector xb: ", xb.shape)
    print("Query vector xq: ", xq.shape)
    print("Ground truth gt: ", gt.shape)
else:
    # Only the SIFT/BIGANN family is supported by this notebook.
    print('unknown dataset', dbname, file=sys.stderr)
    sys.exit(1)

Vector shapes:
Base vector xb:  (1000000, 128)
Query vector xq:  (10000, 128)
Ground truth gt:  (10000, 1000)


In [31]:
def distributed_search(query_vec, kmeans, index_list, k, ef, all_vectors):
    """
    Search k-means-partitioned HNSW subgraphs for the k nearest neighbours of one
    query, hopping to other partitions when a subgraph search asks for it.

    query_vec: a numpy array of a single d-dimensional vector
    kmeans: the kmeans object; its ``predict`` chooses the first partition to search
    index_list: a list of loaded python object of hnsw index, indexed by partition id
    k: number of neighbours to return
    ef: search breadth, passed through to ``searchKnnPlusRemoteCache``
    all_vectors: base-vector array, passed through to ``searchKnnPlusRemoteCache``

    Returns (results, search_path):
        results: at most k tuples (dist, server_ID, vec_ID) sorted by ASCENDING
            distance, so results[0] is the nearest neighbour found
        search_path: partition ids in the order they were searched; every
            partition appears at most once
    """
    # sklearn's predict is given float64 here (the query itself stays unchanged)
    query_kmeans_format = query_vec.reshape(1, -1).astype(np.float64)
    partition_id = kmeans.predict(query_kmeans_format)[0]
    search_path = []
    all_local_results = []

    while True:
        current_index = index_list[partition_id]
        search_path.append(partition_id)

        _, local_results, _, search_remote, remote_partition_id = \
            current_index.searchKnnPlusRemoteCache(query_vec, k, ef, all_vectors, debug=False)
        all_local_results += local_results
        # Stop when no remote hop is requested, or when the hop would revisit an
        # already-searched partition (this also guarantees termination).
        if not search_remote or remote_partition_id in search_path:
            break
        partition_id = remote_partition_id

    # Merge: keep the k best candidates. Descending order of (-dist, server_ID, vec_ID)
    # is ascending distance, with equal distances ordered by larger server_ID/vec_ID
    # first -- exactly the order the previous push/pop heap merge produced.
    results = heapq.nlargest(
        k,
        [(dist, server_ID, vec_ID) for dist, server_ID, vec_ID in all_local_results],
        key=lambda r: (-r[0], r[1], r[2]))

    return results, search_path

In [14]:
N_SUBGRAPH = 4
# Directory with the pickled k-means model and per-partition subgraph indexes.
# Built from dbname and N_SUBGRAPH so the path cannot drift out of sync with them
# (for SIFT1M / 4 this is '../indexes_subgraph_kmeans/SIFT1M_4_subgraphs').
parent_dir = '../indexes_subgraph_kmeans/{}_{}_subgraphs'.format(dbname, N_SUBGRAPH)

In [15]:
# One pickled subgraph index (with remote edges) per partition, ordered by partition id.
# NOTE: these are unpickled -- only load index files you generated yourself.
all_hnsw_indexes = [
    load_obj(parent_dir, 'subgraph_{}_with_remote_edges'.format(subgraph_id))
    for subgraph_id in range(N_SUBGRAPH)
]

In [16]:
# Size of each subgraph's `remote_links` collection, one line per partition.
for i in range(N_SUBGRAPH):
    subgraph = all_hnsw_indexes[i]
    print(len(subgraph.remote_links))

358987
257516
215615
167882


In [17]:
# k-means model whose `predict` picks each query's starting partition in distributed_search.
kmeans = load_obj(dirc=parent_dir, name='kmeans')

In [45]:
# Run the distributed search for the first `query_num` queries.
# TOP_K = 1 because the recall cell below measures R@1 against gt[i][0].
TOP_K = 1
EF_SEARCH = 128
query_num = 10000

result_list = []
search_path_list = []
for i in range(query_num):
    results, search_path = distributed_search(
        xq[i], kmeans, index_list=all_hnsw_indexes, k=TOP_K, ef=EF_SEARCH, all_vectors=xb)
    # distributed_search returns results sorted by ascending distance, so the nearest
    # neighbour is results[0] (results[-1] would be the k-th nearest when k > 1).
    result_list.append(results[0])
    search_path_list.append(search_path)

In [46]:
## Recall@1 of the distributed search (queries may hop across partitions)
## Wenqi comment: for k-means-based method, the recall is really high
# First 100 queries -> 1.0 recall
# First 1000 queries -> 0.996 recall
# First 10000 queries -> 0.996 recall (visit remote graph = 0.9978)

# A hit: the returned vector ID (tuple slot 2) equals the ground-truth nearest neighbour.
count = sum(1 for i in range(query_num) if result_list[i][2] == gt[i][0])
print(count/query_num, count)

0.996 9960


In [47]:
# How often does a search leave its starting partition, and how long are the paths?
# First 100 queries -> 9% travel to a remote partition; average search path length = 1.09
# (presumably first 1000 queries) -> 11.8% travel to a remote partition; average search path length = 1.118
# First 10000 queries -> 10.96% travel to a remote partition; average search path length = 1.1097

path_lengths = [len(search_path_list[i]) for i in range(query_num)]
total_path_length = sum(path_lengths)
# a path longer than 1 means at least one partition besides the starting one was searched
search_remote_count = sum(length > 1 for length in path_lengths)
average_path_length = total_path_length / query_num
print("search remote rate: {} ({} cases)".format(search_remote_count/query_num, search_remote_count))
print("average path length: {}".format(average_path_length))

search remote rate: 0.1096 (1096 cases)
average path length: 1.1097
