In [1]:
import heapq
import os
import pickle
import struct
import sys
import time

import hnswlib
import numpy as np
from sklearn.cluster import KMeans

from distributed_graph_index_construction import HNSW_index

In [2]:
def load_obj(dirc, name):
    """
    Unpickle and return the object stored in `<dirc>/<name>.pkl`.

    Input:
        dirc: directory that contains the pickle file
        name: file name without the '.pkl' extension
    Output:
        the unpickled Python object
    """
    # NOTE: pickle.load can execute arbitrary code; only use on trusted files.
    pkl_path = os.path.join(dirc, name + '.pkl')
    with open(pkl_path, 'rb') as handle:
        obj = pickle.load(handle)
    return obj

In [3]:
def recall_eval(result_list, gt, k):
    """
    Compute recall@k of search results against the ground truth.

    Input:
        result list: a 2-dim list
            dim 1: query num
            dim 2: topK
        gt: a ground truth 2-d numpy array
            dim 1: query num
            dim 2: topK, 1000 for sift dataset
        k: topK to be used for recall evaluation,
            *** can be anything smaller than the dim2 of result_list ***)
    Output:
        recall: (number of the first k results of each query that appear in
            the first k ground-truth IDs of that query) / (query num * k)
    Raises:
        ValueError: if result_list is empty or k is not positive
    """
    # The number of queries comes from the results themselves instead of a
    # notebook-level global, so the function is self-contained.
    query_num = len(result_list)
    if query_num == 0:
        raise ValueError("result_list must contain at least one query")
    if k <= 0:
        raise ValueError("k must be a positive integer")

    count = 0
    for i, results in enumerate(result_list):
        gt_set = set(gt[i][:k])
        # Slicing tolerates queries that returned fewer than k results; the
        # missing entries simply count as misses (the denominator stays k).
        for vec_ID in results[:k]:
            if vec_ID in gt_set:
                count += 1
    recall = count / (query_num * k)
    return recall

In [4]:
def mmap_fvecs(fname):
    """
    Memory-map a .fvecs file and return its vectors as a 2-d float32 view.

    The file is read as a flat int32 array; its first element is taken as the
    dimension d. The data is then reinterpreted as float32, reshaped into rows
    of (d + 1) elements, and the leading element of every row is dropped.
    """
    raw = np.memmap(fname, dtype='int32', mode='r')
    d = raw[0]
    as_float = raw.view('float32')
    records = as_float.reshape(-1, d + 1)
    return records[:, 1:]

def mmap_bvecs(fname):
    """
    Memory-map a .bvecs file and return its vectors as a 2-d uint8 view.

    The first 4 bytes of the file are reinterpreted as an int32 dimension d.
    The bytes are reshaped into rows of (d + 4) elements and the 4 leading
    bytes of every row are dropped.
    """
    raw = np.memmap(fname, dtype='uint8', mode='r')
    header = raw[:4].view('int32')
    d = header[0]
    records = raw.reshape(-1, d + 4)
    return records[:, 4:]

def ivecs_read(fname):
    """
    Read a whole .ivecs file into memory and return a 2-d int32 array.

    Wenqi: Format of ground truth (for 10000 query vectors):
      1000(topK), [1000 ids]
      1000(topK), [1000 ids]
           ...     ...
      1000(topK), [1000 ids]
    10000 rows in total, 10000 * 1001 elements, 10000 * 1001 * 4 bytes
    """
    flat = np.fromfile(fname, dtype='int32')
    d = flat[0]
    records = flat.reshape(-1, d + 1)
    # Drop the leading per-row length field; copy so the result owns its data.
    payload = records[:, 1:]
    return payload.copy()

def fvecs_read(fname):
    """
    Read a whole .fvecs file into memory as a 2-d float32 array.

    The file is parsed by ivecs_read and the resulting int32 array is
    reinterpreted (not converted) as float32.
    """
    as_int32 = ivecs_read(fname)
    return as_int32.view('float32')

In [5]:
# Dataset configuration. `dbname` has the form 'SIFT<n>M'; <n> (in millions)
# selects how many base vectors and which ground-truth file are used.
dbname = 'SIFT1M'
index_path='../indexes/{}_index.bin'.format(dbname)
dim=128

if dbname.startswith('SIFT'):
    # SIFT1M to SIFT1000M
    dbsize = int(dbname[4:-1])
    xb = mmap_bvecs('/mnt/scratch/wenqi/Faiss_experiments/bigann/bigann_base.bvecs')
    xq = mmap_bvecs('/mnt/scratch/wenqi/Faiss_experiments/bigann/bigann_query.bvecs')
    gt = ivecs_read('/mnt/scratch/wenqi/Faiss_experiments/bigann/gnd/idx_%dM.ivecs' % dbsize)

    N_VEC = int(dbsize * 1000 * 1000)

    # trim xb to correct size (xb stays a memory-mapped view)
    xb = xb[:N_VEC]

    # Wenqi: load xq to main memory as float32. A single np.array call both
    # copies the memory-mapped data into RAM and converts the dtype.
    xq = np.array(xq, dtype=np.float32)
    gt = np.array(gt, dtype=np.int32)

    print("Vector shapes:")
    print("Base vector xb: ", xb.shape)
    print("Query vector xq: ", xq.shape)
    print("Ground truth gt: ", gt.shape)
else:
    # Relies on `sys` being imported in the notebook's import cell.
    print('unknown dataset', dbname, file=sys.stderr)
    sys.exit(1)

Vector shapes:
Base vector xb:  (1000000, 128)
Query vector xq:  (10000, 128)
Ground truth gt:  (10000, 1000)


In [6]:
def distributed_search(query_vec, kmeans, index_list, k, ef, all_vectors):
    """
    Search one query across partitioned HNSW sub-graphs, hopping between
    partitions as directed by each sub-graph's search result.

    Input:
        query_vec: a numpy array of a single d-dimensional vector
        kmeans: the kmeans object; its predict() picks the first partition
        index_list: a list of loaded python object of hnsw index,
            indexed by partition ID
        k: number of results to return
        ef: search-queue size; must satisfy ef >= k
        all_vectors: forwarded unchanged to searchKnnPlusRemoteCache
    Output:
        results: the first k entries of the last visited partition's results
        search_path: partition IDs in the order they were visited
        graph_path_len_list: graph_path_len reported by each visited partition
    """
    assert ef >= k

    # k-means decides which partition is searched first.
    kmeans_input = query_vec.reshape(1, -1).astype(np.float64)
    partition_id = kmeans.predict(kmeans_input)[0]

    search_path = []
    graph_path_len_list = []  # one entry per visited partition
    entry_point_ID = None

    # Every hop asks for ef results (k=ef) so that the full candidate list can
    # be handed to the next partition; pruning to k happens once at the end.
    while True:
        current_index = index_list[partition_id]
        search_path.append(partition_id)

        if len(search_path) == 1:
            # First partition: no entry point / previous results to pass on.
            hop_kwargs = {}
        else:
            # Later partitions continue from what the previous hop reported.
            hop_kwargs = {'ep_vec_id': entry_point_ID, 'existing_results': results}

        (results, local_results, remote_results, search_remote,
         remote_partition_id, remote_ep_vec_ID, graph_path_len) = \
            current_index.searchKnnPlusRemoteCache(
                query_vec, k=ef, ef=ef, all_vectors=all_vectors, **hop_kwargs)

        entry_point_ID = remote_ep_vec_ID
        graph_path_len_list.append(graph_path_len)

        # Stop when no remote search is requested, or when the requested
        # partition has already been visited (avoids cycling).
        if not search_remote or remote_partition_id in search_path:
            break
        partition_id = remote_partition_id

    # NOTE(review): the original comment stated the results are "in descending
    # order for distance"; only the first k entries are kept here, so confirm
    # the ordering returned by searchKnnPlusRemoteCache.
    return results[:k], search_path, graph_path_len_list

## Load Index

In [7]:
# Number of sub-graphs (one HNSW index per k-means partition) and the
# directory that stores them.
N_SUBGRAPH = 32
SUBGRAPH_DIR_TEMPLATE = '../indexes_subgraph_kmeans/SIFT1M_{}_subgraphs'
parent_dir = SUBGRAPH_DIR_TEMPLATE.format(N_SUBGRAPH)

In [8]:
# Load the pickled index of every sub-graph (the variants saved with remote
# edges), in partition-ID order.
all_hnsw_indexes = []
for subgraph_id in range(N_SUBGRAPH):
    subgraph_name = 'subgraph_{}_with_remote_edges'.format(subgraph_id)
    all_hnsw_indexes.append(load_obj(parent_dir, subgraph_name))

In [9]:
# Print the size of each sub-graph's `remote_links` container, one per line.
for subgraph_id in range(N_SUBGRAPH):
    links = all_hnsw_indexes[subgraph_id].remote_links
    print(len(links))

32085
34445
32463
28116
30155
27535
32671
37909
40761
32794
29259
29783
27649
29474
28696
30636
36666
29190
35601
27197
26314
40255
33383
27533
33262
31533
32801
32683
27502
29037
25241
27371


In [10]:
# Load the pickled k-means object stored next to the sub-graph indexes.
kmeans = load_obj(dirc=parent_dir, name='kmeans')

## Recall on distributed graph search

In [17]:
# Run the distributed search for the first `query_num` queries and keep, per
# query: the results, the visited partitions, and the per-partition graph
# path lengths.
result_list, search_path_list, all_graph_path_len_list = [], [], []
query_num = 10000
k = 100

for i in range(query_num):
    search_output = distributed_search(
        xq[i], kmeans, index_list=all_hnsw_indexes, k=k, ef=128, all_vectors=xb)
    results, search_path, graph_path_len_list = search_output

    result_list.append(results)
    search_path_list.append(search_path)
    all_graph_path_len_list.append(graph_path_len_list)

In [18]:
## Recall of the distributed graph search.
## NOTE(review): the original heading said "consider up to 1 remote hop", but
##   distributed_search keeps hopping until no new partition is requested.
## Wenqi comment: for k-means-based method, the recall is really high
# First 100 queries -> 1.0 recall
# First 1000 queries -> 0.994 recall
# First 10000 queries -> 0.9916 recall

# Each result entry is a tuple-like record; element 2 is used as the vector ID.
result_list_I = [[r[2] for r in per_query] for per_query in result_list]

print("R@1 =", recall_eval(result_list=result_list_I, gt=gt, k=1))
print("R@10 =", recall_eval(result_list=result_list_I, gt=gt, k=10))
print("R@100 =", recall_eval(result_list=result_list_I, gt=gt, k=100))

R@1 = 0.9912
R@10 = 0.9861
R@100 = 0.962541


In [19]:
# Count how many searches travel to a remote node.
# Recorded measurements (labels kept as originally written; the 2nd and 3rd
# lines presumably refer to the first 1000 / 10000 queries -- TODO confirm):
# First 100 queries -> 31% travels to remote node; average search path length = 1.31
# First 100 queries -> 30% travels to remote node; average search path length = 1.304
# First 100 queries -> 29.91% travels to remote node; average search path length = 1.3054 (1 case travel to a third server)

search_remote_count = 0
total_path_length = 0
# Number of partitions (servers) visited by each query
path_len = np.array([len(path) for path in search_path_list])
# key = number of visited servers; value = number of queries with that count
len_count = dict()
# key = position j of a server in the search path (0 = first searched);
# value = (sum of graph path lengths at position j, number of queries that
#          reached position j)
graph_path_len_server = dict()

for i in range(query_num):
    # Server-level path length
    hops = path_len[i]
    total_path_length += hops
    len_count[hops] = len_count.get(hops, 0) + 1
    if hops > 1:
        search_remote_count += 1

    # Graph-level path length on the j-th searched server
    for j, leng in enumerate(all_graph_path_len_list[i]):
        prev_total, prev_cases = graph_path_len_server.get(j, (0, 0))
        graph_path_len_server[j] = (prev_total + leng, prev_cases + 1)

# Average graph path length on the j-th searched server
ave_graph_path_len_server = {
    j: total / cases for j, (total, cases) in graph_path_len_server.items()
}

average_path_length = total_path_length / query_num
print("search remote rate: {} ({} cases)".format(search_remote_count/query_num, search_remote_count))
print("average path length: {}".format(average_path_length))
print("search path length distribution: {}".format(len_count))
print("average graph search path length on k-th searched sub-graph: {}".format(ave_graph_path_len_server))

search remote rate: 0.2976 (2976 cases)
average path length: 1.3024
search path length distribution: {2: 2928, 1: 7024, 3: 48}
average graph search path length on k-th searched sub-graph: {0: 133.1323, 1: 33.845766129032256, 2: 20.895833333333332}


## Partition-based search

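This section does not use distributed (cross-partition) graph search. Instead, the K-means centroids are used to pick the m nearest partitions for each query, each selected partition is searched independently, and the results are merged. We explore the relationship between m and recall.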

In [None]:
""" Load hnswlib index """
# One hnswlib index per sub-graph; server ID i corresponds to subgraph_<i>.bin.
all_server_IDs = np.arange(N_SUBGRAPH)
all_hnswlib_indexes = [hnswlib.Index(space='l2', dim=dim) for i in all_server_IDs]
# Derive the directory from N_SUBGRAPH so the path and the number of indexes
# loaded below cannot get out of sync (previously hard-coded to 32).
parent_dir = '../indexes_subgraph_kmeans/SIFT1M_{}_subgraphs'.format(N_SUBGRAPH)
all_index_paths=[os.path.join(parent_dir, 'subgraph_{}.bin'.format(i)) for i in all_server_IDs]
for i in all_server_IDs:
    print("\nLoading hnswlib index from {}\n".format(all_index_paths[i]))
    all_hnswlib_indexes[i].load_index(all_index_paths[i])

In [None]:
# Centroids of the fitted k-means model; print their shape and values.
cluster_centers = kmeans.cluster_centers_
centers_shape = cluster_centers.shape
print(centers_shape, cluster_centers)

In [None]:
def compute_centroid_distances(cluster_centers, query_vecs):
    """
    Squared L2 distance between every query vector and every centroid.

    Input:
        cluster_centers: 2-d array (num_clusters, dim)
        query_vecs: 2-d array (num_queries, dim)
    Output:
        distance_mat (num_queries, num_clusters),
            each element is a distance (L2 square)
    """
    num_clusters, dim = cluster_centers.shape
    nq = query_vecs.shape[0]
    assert dim == query_vecs.shape[1]

    distance_mat = np.zeros((nq, num_clusters))

    # One centroid at a time: broadcasting subtracts the centroid from every
    # query row, then the squared differences are summed per query.
    for cid, centroid in enumerate(cluster_centers):
        diff = query_vecs - centroid
        distance_mat[:, cid] = np.sum(diff ** 2, axis=1)

    return distance_mat

def kmeans_predict_sorted(cluster_centers, query_vecs):
    """
    For each query, rank all centroid IDs by increasing squared L2 distance.

    Input:
        cluster_centers: 2-d array (num_clusters, dim)
        query_vecs: 2-d array (num_queries, dim)
    Output:
        ID_mat (num_queries, num_clusters),
            each element is a centroid ID; column 0 holds the closest centroid
    """
    # Shape validation happens inside compute_centroid_distances.
    distances = compute_centroid_distances(cluster_centers, query_vecs)
    return np.argsort(distances, axis=1)

In [None]:
# Evaluate function correctness
print(kmeans_predict_sorted(cluster_centers, xq[:10]))
# query_kmeans_format = query_vec.reshape(1,-1).astype(np.float64)
partition_id = kmeans.predict(xq[:10].astype(np.float64))
print(partition_id)

In [None]:
# For every query, partition (centroid) IDs ordered from nearest to farthest.
sorted_partition_IDs = kmeans_predict_sorted(
    cluster_centers=cluster_centers, query_vecs=xq)

In [None]:
# Display (num_queries, num_clusters)
np.shape(sorted_partition_IDs)

In [None]:
""" Search several partitions per query vector """
MAX_VISITED_PARTITIONS = 8  # number of nearest partitions searched per query
nq, dim = xq.shape
k = 100

# Row = query; columns [j*k, (j+1)*k) hold the top-k IDs / distances returned
# by the j-th nearest partition of that query.
result_width = k * MAX_VISITED_PARTITIONS
all_I = np.zeros((nq, result_width), dtype=int)
all_D = np.zeros((nq, result_width))

for hnsw_index in all_hnswlib_indexes:
    hnsw_index.set_ef(128)

for vec_id in range(nq):
    if vec_id % 1000 == 0:
        print("query id: ", vec_id)
    for j in range(MAX_VISITED_PARTITIONS):
        index_id = sorted_partition_IDs[vec_id][j]
        labels, distances = all_hnswlib_indexes[index_id].knn_query(xq[vec_id], k=k)
        cols = slice(j * k, (j + 1) * k)
        all_I[vec_id, cols] = labels
        all_D[vec_id, cols] = distances

In [None]:
# Expected: (num_queries, k * MAX_VISITED_PARTITIONS)
print(np.shape(all_I))

In [None]:
# Compute the recall combining {k} x {num_partitions (1~MAX_VISITED_PARTITIONS)}
for tmp_k in [1, 10, 100]:

    # Keep only the first tmp_k results of every partition's k-wide block, so
    # that partition j occupies columns [j*tmp_k, (j+1)*tmp_k).
    packed_width = tmp_k * MAX_VISITED_PARTITIONS
    all_I_k_tmp = np.zeros((nq, packed_width), dtype=int)
    all_D_k_tmp = np.zeros((nq, packed_width))
    for j in range(MAX_VISITED_PARTITIONS):
        src = slice(j * k, j * k + tmp_k)
        dst = slice(j * tmp_k, (j + 1) * tmp_k)
        all_I_k_tmp[:, dst] = all_I[:, src]
        all_D_k_tmp[:, dst] = all_D[:, src]

    # for upto MAX_VISITED_PARTITIONS partition, compute recall
    for tmp_partition in range(1, 1 + MAX_VISITED_PARTITIONS):

        # Candidates of the first `tmp_partition` partitions are exactly the
        # first tmp_partition * tmp_k packed columns.
        n_candidates = tmp_partition * tmp_k
        D_I_k_tmp = [
            list(zip(all_D_k_tmp[vec_id][:n_candidates],
                     all_I_k_tmp[vec_id][:n_candidates]))
            for vec_id in range(nq)
        ]

        # Merge: sort (distance, ID) pairs and keep the tmp_k closest.
        D_k_tmp = []
        I_k_tmp = []
        for candidates in D_I_k_tmp:
            D_I_tmp = sorted(candidates)[:tmp_k]
            D_k_tmp.append([D for D, I in D_I_tmp])
            I_k_tmp.append([I for D, I in D_I_tmp])

        print("Num partition = {}\tR@{} = {}".format(
            tmp_partition, tmp_k, recall_eval(result_list=I_k_tmp, gt=gt, k=tmp_k)))