In [1]:
import os
import re
import sys
import time

import faiss
import numpy as np

sys.path.append("..")
from utils.vdb_utils import load_index, random_queries

In [2]:
# query index
def query_index(index, queries, k):
    '''
    Run a k-nearest-neighbour search on an already-loaded index.

    args:
        - index: object exposing .search(queries, k) (e.g. a faiss index)
        - queries: query vectors, passed straight to index.search
        - k: number of neighbours requested per query

    return:
        - (D, I): the two arrays produced by index.search
    '''
    distances, labels = index.search(queries, k)
    return distances, labels

def query_index_file(index_path, queries, k):
    '''
    Load the index stored at index_path (via load_index) and search it.

    This function does not cache the loaded index: every call goes through
    load_index again.

    args:
        - index_path: path handed to load_index
        - queries: query vectors for the search
        - k: number of neighbours requested per query

    return:
        - (D, I) as produced by query_index
    '''
    return query_index(load_index(index_path), queries, k)

def batch_query_index(index, queries, k, batch_size=1000):
    '''
    Search `queries` against an already-loaded index in chunks of
    `batch_size` rows and stitch the per-chunk results back together.

    args:
        - index: object exposing .search(queries, k) (e.g. a faiss index)
        - queries: (num_queries, dim) array, sliced row-wise into batches
        - k: number of neighbours requested per query
        - batch_size: maximum number of rows sent to index.search at once

    return:
        - D: (num_queries, k) distances, rows in the same order as queries
        - I: (num_queries, k) ids, rows in the same order as queries
        With no queries, both are empty (0, k) arrays.

    raises:
        - ValueError: if batch_size is not positive
    '''
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    # Collect the chunks and concatenate once. Growing an array inside the
    # loop re-copies everything on each batch, and the old `.size` guard
    # dropped every batch but the last when k == 0.
    D_parts, I_parts = [], []
    for start in range(0, len(queries), batch_size):
        # Same call query_index makes; issued directly on the index.
        D_batch, I_batch = index.search(queries[start:start + batch_size], k)
        D_parts.append(D_batch)
        I_parts.append(I_batch)

    if not D_parts:
        # No queries: return correctly shaped empty results (faiss dtypes).
        return np.empty((0, k), dtype=np.float32), np.empty((0, k), dtype=np.int64)
    return np.concatenate(D_parts, axis=0), np.concatenate(I_parts, axis=0)

In [101]:
def search_outterloop_query(stopology, queries, idx_k, k, idx_paths, search_fn=None):
    '''
    Search a batch of queries with the query loop outermost. For each query,
    search every index shard it is routed to, merge the per-shard hits, and
    keep the global top k.

    args:
        - stopology: search topology {query_idx: [shard_idx1, ...]}
        - queries: (num_queries, dim) array; rows selected by query_idx
        - idx_k: number of neighbours requested from each shard
        - k: number of global top results kept per query
        - idx_paths: list of index paths; shard_idx is a position in it
        - search_fn: optional callable (index_path, queries, k) -> (D, I).
          Defaults to query_index_file, which loads the shard on every call.
          Pass e.g. a caching searcher to avoid repeated loads.

    return:
        - D: (num_queries, k) float distances, ascending along each row
        - I: (num_queries, k) int ids returned by the shard searches
        - file_idx: (num_queries, k) int shard index each hit came from
        A slot with no candidate holds distance inf and id/shard -1. This
        happens when the query is absent from stopology, has no shards, or
        gathered fewer than k hits.
    '''
    if search_fn is None:
        search_fn = query_index_file

    num_queries = queries.shape[0]
    # Sentinels: inf sorts after every real distance, and -1 is never a valid
    # position or shard. The old zero init made unfilled rows look like
    # perfect matches on shard 0.
    final_D_matrix = np.full((num_queries, k), np.inf)
    final_I_matrix = np.full((num_queries, k), -1, dtype=int)
    final_file_idx_matrix = np.full((num_queries, k), -1, dtype=int)

    for q_idx, idxs in stopology.items():
        # keep a 2-D (1, dim) shape for the search call
        query = queries[q_idx].reshape(1, -1)

        D_parts, I_parts, file_idx_parts = [], [], []
        for file_idx in idxs:
            D, I = search_fn(idx_paths[file_idx], query, idx_k)
            D_parts.append(D)
            I_parts.append(I)
            # tag every hit with the shard it came from
            file_idx_parts.append(np.full(D.shape, file_idx, dtype=int))
        if not D_parts:
            continue  # routed to no shard: leave the sentinel row

        D_concat = np.concatenate(D_parts, axis=1)
        I_concat = np.concatenate(I_parts, axis=1)
        file_idx_concat = np.concatenate(file_idx_parts, axis=1)

        # Order all candidates of this query by distance.
        sort_idx = np.argsort(D_concat, axis=1)
        D_concat = np.take_along_axis(D_concat, sort_idx, axis=1)
        I_concat = np.take_along_axis(I_concat, sort_idx, axis=1)
        file_idx_concat = np.take_along_axis(file_idx_concat, sort_idx, axis=1)

        # With fewer than k candidates, fill what we have (the old full-row
        # assignment raised a broadcast error here).
        n_keep = min(k, D_concat.shape[1])
        final_D_matrix[q_idx, :n_keep] = D_concat[0, :n_keep]
        final_I_matrix[q_idx, :n_keep] = I_concat[0, :n_keep]
        final_file_idx_matrix[q_idx, :n_keep] = file_idx_concat[0, :n_keep]
    return final_D_matrix, final_I_matrix, final_file_idx_matrix
    

def batch_queries_by_stopology(stopology, queries):
    '''
    For a shard-major (outer loop over index shards) search topology,
    gather the rows of the global query batch routed to each shard.

    args:
        - stopology: search topology {index1: [q1, q2, ...]}
        - queries: (num, dim) array

    return:
        - dict {index: query_batch}, in the same key order as stopology,
          where query_batch = queries[[q1, q2, ...]], shape
          (num_queries_for_index, dim), rows in the listed order
    '''
    return {shard: queries[q_list] for shard, q_list in stopology.items()}

def search_outterloop_index(stopology, queries, idx_k, k, idx_paths, search_fn=None):
    '''
    Search a batch of queries with the index-shard loop outermost. Each
    shard is searched once with all queries routed to it, and the hits are
    merged into a running per-query top k.

    args:
        - stopology: search topology {shard_idx: [q1, q2, ...]}
        - queries: (num_queries, dim) array; rows selected by q indices
        - idx_k: number of neighbours requested from each shard
        - k: number of global top results kept per query
        - idx_paths: list of index paths; shard_idx is a position in it
        - search_fn: optional callable (index_path, queries, k) -> (D, I).
          Defaults to query_index_file.

    return:
        - D: (num_queries, k) float distances, ascending along each row
        - I: (num_queries, k) int ids returned by the shard searches
        - file_idx: (num_queries, k) int shard index each hit came from
        A slot never filled by a real hit holds distance inf and id/shard -1.
    '''
    if search_fn is None:
        search_fn = query_index_file

    num_queries = queries.shape[0]
    # inf distances sort last, so unfilled slots are displaced by any real
    # hit. -1 (not 0) marks them, because 0 is a valid id and shard index.
    final_D_matrix = np.full((num_queries, k), np.inf)
    final_I_matrix = np.full((num_queries, k), -1, dtype=int)
    final_file_idx_matrix = np.full((num_queries, k), -1, dtype=int)

    for file_idx, q_idxs in stopology.items():
        if len(q_idxs) == 0:
            continue  # nothing routed here: don't load the shard at all
        # array form so fancy indexing is unambiguous (a tuple would not be)
        q_idxs = np.asarray(q_idxs, dtype=int)

        # slice this shard's queries on demand instead of copying all batches
        D, I = search_fn(idx_paths[file_idx], queries[q_idxs], idx_k)
        file_idx_m = np.full(D.shape, file_idx, dtype=int)

        # Merge the running top-k of these queries with the new hits.
        D_concat = np.concatenate((final_D_matrix[q_idxs], D), axis=1)
        I_concat = np.concatenate((final_I_matrix[q_idxs], I), axis=1)
        file_idx_concat = np.concatenate((final_file_idx_matrix[q_idxs], file_idx_m), axis=1)

        # Re-rank by distance and keep the best k.
        sort_idx = np.argsort(D_concat, axis=1)
        final_D_matrix[q_idxs] = np.take_along_axis(D_concat, sort_idx, axis=1)[:, :k]
        final_I_matrix[q_idxs] = np.take_along_axis(I_concat, sort_idx, axis=1)[:, :k]
        final_file_idx_matrix[q_idxs] = np.take_along_axis(file_idx_concat, sort_idx, axis=1)[:, :k]
    return final_D_matrix, final_I_matrix, final_file_idx_matrix

In [102]:
# args
index_root = "../shards/idxs/"
nprobe = 10  # number of shards (centroids) each query is routed to

# some initializations
k = 5             # global top-k kept per query
dim = 64          # vector dimensionality passed to random_queries
num_shards = 50   # expected number of shard index files under index_root
num_queries = 5


def _natural_key(name):
    """Sort key that compares embedded integers numerically ("s2" < "s10")."""
    return [int(tok) if tok.isdigit() else tok for tok in re.split(r"(\d+)", name)]


# The shard ids returned by the centroid search are used as positions in
# idx_paths, so idx_paths must be in shard-id order. os.listdir() returns
# entries in arbitrary order, so sort the names naturally first.
# NOTE(review): assumes each shard file name carries its shard id as a
# number (e.g. "shard_7.index") — confirm against the sharding script.
idx_paths = []
centroid_idx_path = ""
for fname in sorted(os.listdir(index_root), key=_natural_key):
    if "centroid" in fname:
        centroid_idx_path = os.path.join(index_root, fname)
    else:
        idx_paths.append(os.path.join(index_root, fname))

assert centroid_idx_path, f"no centroid index found in {index_root}"
assert len(idx_paths) == num_shards, (
    f"expected {num_shards} shard indexes in {index_root}, found {len(idx_paths)}"
)

# process queries
# NOTE(review): seeding only makes runs reproducible if random_queries draws
# from NumPy's global RNG — confirm.
np.random.seed(0)
queries = random_queries(num_queries, dim)

# Per-shard k. It is always >= k, so no single shard is truncated below the
# global top-k it could contribute.
idx_k = (k // nprobe) + k
print(idx_k)

# knn find top centroids: row q of I lists the nprobe shard ids for query q
D, I = query_index_file(centroid_idx_path, queries, nprobe)
print(I)
print(I.shape)

5
[[15 10 47 35  6 16  0 27 20  5]
 [36 17 28 37  6  8 33 10  0 43]
 [ 6 32 44 35 23 40  3 43 34  4]
 [ 6 38 46 17 30 11 41 25 33  0]
 [ 0 17 19 42 34 11 30 15 45 49]]
(5, 10)


# Test outer loop over queries (`search_outterloop_query`)

In [103]:
# Query-major search topology: query index -> shard indices picked for it
# by the centroid search above.
stopology = {i: list(I[i]) for i in range(num_queries)}
print(stopology)
print()

final_D_matrix, final_I_matrix, final_file_idx_matrix = search_outterloop_query(
    stopology, queries, idx_k, k, idx_paths
)

{0: [15, 10, 47, 35, 6, 16, 0, 27, 20, 5], 1: [36, 17, 28, 37, 6, 8, 33, 10, 0, 43], 2: [6, 32, 44, 35, 23, 40, 3, 43, 34, 4], 3: [6, 38, 46, 17, 30, 11, 41, 25, 33, 0], 4: [0, 17, 19, 42, 34, 11, 30, 15, 45, 49]}



In [108]:
# distances, ids, and source shard of each hit, separated by blank lines
print(final_D_matrix, final_I_matrix, final_file_idx_matrix, sep="\n\n")

[[5.28747559 6.3287611  6.49794102 6.5273447  6.56514645]
 [4.83784533 5.56429386 5.80746508 5.8135066  5.87090874]
 [5.14307261 5.3457408  5.36828613 5.49654913 5.62543297]
 [5.0963521  5.26672077 5.27462673 5.27968597 5.34793758]
 [5.09117603 5.09383678 5.17984581 5.2782383  5.3036623 ]]

[[ 4248 87728  2827 19288 37462]
 [50265  7226 57973 49381 17145]
 [13546 43401 92565 80600 15484]
 [14062 44502 17928 98982 26072]
 [ 5584 19312 93331 29457 36489]]

[[ 5 15  5 35 35]
 [33  8  0 17  6]
 [ 3 34  3  6 34]
 [41 33 11 33 30]
 [ 0 15 11 45 11]]


# Test outer loop over index shards (`search_outterloop_index`)

In [109]:
# Invert the centroid routing into a shard-major topology:
# shard index -> query indices routed to that shard (first-seen key order).
idx_query_stopology = {}
for q_idx, shard_ids in enumerate(I):
    for shard_id in shard_ids:
        idx_query_stopology.setdefault(shard_id, []).append(q_idx)

print(idx_query_stopology)

{15: [0, 4], 10: [0, 1], 47: [0], 35: [0, 2], 6: [0, 1, 2, 3], 16: [0], 0: [0, 1, 3, 4], 27: [0], 20: [0], 5: [0], 36: [1], 17: [1, 3, 4], 28: [1], 37: [1], 8: [1], 33: [1, 3], 43: [1, 2], 32: [2], 44: [2], 23: [2], 40: [2], 3: [2], 34: [2, 4], 4: [2], 38: [3], 46: [3], 30: [3, 4], 11: [3, 4], 41: [3], 25: [3], 19: [4], 42: [4], 45: [4], 49: [4]}


In [110]:
final_D_matrix2, final_I_matrix2, final_file_idx_matrix2 = search_outterloop_index(
    idx_query_stopology, queries, idx_k, k, idx_paths
)

# distances, ids, and source shard of each hit, separated by blank lines
print(final_D_matrix2, final_I_matrix2, final_file_idx_matrix2, sep="\n\n")

[[5.28747559 6.3287611  6.49794102 6.5273447  6.56514645]
 [4.83784533 5.56429386 5.80746508 5.8135066  5.87090874]
 [5.14307261 5.3457408  5.36828613 5.49654913 5.62543297]
 [5.0963521  5.26672077 5.27462673 5.27968597 5.34793758]
 [5.09117603 5.09383678 5.17984581 5.2782383  5.3036623 ]]

[[ 4248 87728  2827 19288 37462]
 [50265  7226 57973 49381 17145]
 [13546 43401 92565 80600 15484]
 [14062 44502 17928 98982 26072]
 [ 5584 19312 93331 29457 36489]]

[[ 5 15  5 35 35]
 [33  8  0 17  6]
 [ 3 34  3  6 34]
 [41 33 11 33 30]
 [ 0 15 11 45 11]]


In [112]:
# compare two results: both loop orders should yield the same top-k
result_pairs = (
    (final_D_matrix, final_D_matrix2),
    (final_I_matrix, final_I_matrix2),
    (final_file_idx_matrix, final_file_idx_matrix2),
)
for by_query, by_index in result_pairs:
    print(np.array_equal(by_query, by_index))

True
True
True
