In [1]:
import os
import sys
import time
import faiss
import numpy as np

sys.path.append("..")
from utils.vdb_utils import load_index, random_queries, random_floats, random_normal_vectors

In [2]:
# query index
def query_index(index, queries, k):
    '''
    Run a top-k search for a batch of queries against an already loaded index.

    args:
        - index: object exposing search(queries, k)
        - queries: query batch, forwarded unchanged to index.search
        - k: number of results requested per query

    return:
        - (distances, ids): the pair produced by index.search
    '''
    distances, ids = index.search(queries, k)
    return distances, ids

def query_index_file(index_path, queries, k):
    '''
    Load the index stored at index_path and search it for a batch of queries.

    The index is loaded through load_index on every call and is not kept
    around afterwards.

    args:
        - index_path: location handed to load_index
        - queries: query batch
        - k: number of results requested per query

    return:
        - (D, I): the pair returned by query_index for the loaded index
    '''
    return query_index(load_index(index_path), queries, k)

def batch_query_index(index, queries, k, batch_size=1000):
    '''
    Search an index in chunks of at most batch_size queries and stitch the
    per-chunk results back together in query order.

    args:
        - index: loaded index, passed to query_index
        - queries: (num, dim) query batch, sliced along the first axis
        - k: number of results requested per query
        - batch_size: maximum number of queries sent to the index per call

    return:
        - (D, I): per-chunk results concatenated along axis 0. For an empty
          query batch both arrays have shape (0, k).

    raises:
        - ValueError: if batch_size is not a positive number
    '''
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    # Collect the chunks and concatenate once at the end: concatenating inside
    # the loop re-copies everything accumulated so far on every iteration.
    D_batches = []
    I_batches = []
    for start in range(0, len(queries), batch_size):
        D_batch, I_batch = query_index(index, queries[start:start + batch_size], k)
        D_batches.append(D_batch)
        I_batches.append(I_batch)

    if not D_batches:
        # No queries at all: keep the (num, k) layout so callers can still
        # index or concatenate the result.
        # NOTE(review): float32 / int64 mirror what FAISS search returns -- confirm
        # if a different index backend is used.
        return np.empty((0, k), dtype=np.float32), np.empty((0, k), dtype=np.int64)

    return np.concatenate(D_batches, axis=0), np.concatenate(I_batches, axis=0)

In [3]:
def search_outterloop_query(stopology, queries, idx_k, k, idx_paths):
    '''
    Search a batch of queries with the outer loop running over queries: every
    query is sent, one shard at a time, to the index shards listed for it, and
    the per-shard results are merged into a global top k.

    args:
        - search topology: {qidx1: [idx1,...]}, shard ids are positions in idx_paths
        - query: (num, dim)
        - idx_k: k for each index search
        - k: top k results (global)
        - idx_paths: list of index paths

    return:
        - final_D_matrix: (num, k) distances, ascending along each row
        - final_I_matrix: (num, k) ids as returned by the shard searches
        - final_file_idx_matrix: (num, k) shard id each result came from

        Slots that were never filled (query missing from the topology, empty
        shard list, or fewer than k candidates) hold distance inf and id / shard
        id -1, so they cannot be mistaken for a real hit on id 0 of shard 0.
    '''
    num_queries = queries.shape[0]
    final_D_matrix = np.full((num_queries, k), np.inf)
    # Integer storage keeps ids exact; a float matrix would round very large ids.
    final_I_matrix = np.full((num_queries, k), -1, dtype=np.int64)
    final_file_idx_matrix = np.full((num_queries, k), -1, dtype=np.int64)

    # loop over queries
    for q_idx, idxs in stopology.items():
        if len(idxs) == 0:
            # nothing to search for this query; leave the sentinels in place
            continue

        # select query, make shape (1, dim)
        query = queries[q_idx].reshape(1, -1)
        D_parts = []
        I_parts = []
        file_idx_parts = []

        # loop over the shards assigned to this query
        for file_idx in idxs:
            D, I = query_index_file(idx_paths[file_idx], query, idx_k)
            D_parts.append(D)
            I_parts.append(I)
            # remember which shard every candidate came from
            file_idx_parts.append(np.full(D.shape, file_idx, dtype=np.int64))

        # one concatenation per query instead of one per shard
        D_concat = np.concatenate(D_parts, axis=1)
        I_concat = np.concatenate(I_parts, axis=1)
        file_idx_concat = np.concatenate(file_idx_parts, axis=1)

        # sort all candidates of this query by ascending distance
        sort_idx = np.argsort(D_concat, axis=1)
        D_concat = np.take_along_axis(D_concat, sort_idx, axis=1)
        I_concat = np.take_along_axis(I_concat, sort_idx, axis=1)
        file_idx_concat = np.take_along_axis(file_idx_concat, sort_idx, axis=1)

        # keep the global top k; there may be fewer than k candidates in total
        top = min(k, D_concat.shape[1])
        final_D_matrix[q_idx, :top] = D_concat[0, :top]
        final_I_matrix[q_idx, :top] = I_concat[0, :top]
        final_file_idx_matrix[q_idx, :top] = file_idx_concat[0, :top]
    return final_D_matrix, final_I_matrix.astype(int), final_file_idx_matrix.astype(int)
    

def batch_queries_by_stopology(stopology, queries):
    '''
    Gather, for every index of an index-major topology, the rows of the global
    query batch that have to be searched in that index.

    args:
        - search topology: {index1: [q1,q2...]}
        - query: (num, dim)

    return:
        - {index1: query_batch1, index2: query_batch2...}
            - query_batch: (num_queries, dim), rows picked with queries[[q1,q2...]]
    '''
    return {idx: queries[q_idxs] for idx, q_idxs in stopology.items()}

def search_outterloop_index(stopology, queries, idx_k, k, idx_paths):
    '''
    Search a batch of queries with the outer loop running over index shards:
    each shard is loaded once, searched with all the queries routed to it, and
    its results are merged into the running global top k of those queries.

    args:
        - search topology: {index1: [q1,q2...]}, shard ids are positions in idx_paths
        - queries: (num, dim)
        - idx_k: k for each index search
        - k: top k results (global)
        - idx_paths: list of index paths

    return:
        - final_D_matrix: (num, k) distances, ascending along each row
        - final_I_matrix: (num, k) ids as returned by the shard searches
        - final_file_idx_matrix: (num, k) shard id each result came from

        Slots that were never filled (query not routed to any shard, or fewer
        than k candidates in total) hold distance inf and id / shard id -1, so
        they cannot be mistaken for a real hit on id 0 of shard 0.
    '''
    num_queries = queries.shape[0]
    # inf distances make every real candidate win against an unfilled slot
    final_D_matrix = np.full((num_queries, k), np.inf)
    # Integer storage keeps ids exact; a float matrix would round very large ids.
    final_I_matrix = np.full((num_queries, k), -1, dtype=np.int64)
    final_file_idx_matrix = np.full((num_queries, k), -1, dtype=np.int64)

    # batch queries by stopology: {idx1: [q1_data, q2_data...]}
    stopology_queries_dict = batch_queries_by_stopology(stopology, queries)

    # loop over index shards
    for file_idx, q_idxs in stopology.items():
        if len(q_idxs) == 0:
            # no query is routed to this shard; skip loading and searching it
            continue

        # rows of query_batch are in the same order as q_idxs
        query_batch = stopology_queries_dict[file_idx]
        D, I = query_index_file(idx_paths[file_idx], query_batch, idx_k)
        file_idx_m = np.full(D.shape, file_idx, dtype=np.int64)

        # put the current top k of these queries next to the new candidates
        D_concat = np.concatenate((final_D_matrix[q_idxs], D), axis=1)
        I_concat = np.concatenate((final_I_matrix[q_idxs], I), axis=1)
        file_idx_concat = np.concatenate((final_file_idx_matrix[q_idxs], file_idx_m), axis=1)

        # sort by ascending distance, carrying ids and shard ids along
        sort_idx = np.argsort(D_concat, axis=1)
        D_concat = np.take_along_axis(D_concat, sort_idx, axis=1)
        I_concat = np.take_along_axis(I_concat, sort_idx, axis=1)
        file_idx_concat = np.take_along_axis(file_idx_concat, sort_idx, axis=1)

        # update final matrix with top k
        final_D_matrix[q_idxs] = D_concat[:, :k]
        final_I_matrix[q_idxs] = I_concat[:, :k]
        final_file_idx_matrix[q_idxs] = file_idx_concat[:, :k]
    return final_D_matrix, final_I_matrix.astype(int), final_file_idx_matrix.astype(int)

In [23]:
# args
index_root = "../shards/idxs/"
nprobe = 20

# Split the directory listing into the centroid index and the per-shard indexes.
# os.listdir() returns entries in arbitrary order, but idx_paths is later indexed
# with the shard ids coming out of the centroid search, so the listing is sorted
# to make the id -> file mapping reproducible across runs and machines.
# NOTE(review): assumes the lexicographic order of the shard file names matches
# the shard numbering used by the centroid index (e.g. zero-padded names) -- TODO confirm.
idx_paths = []
centriod_idx_paths = ""
for f in sorted(os.listdir(index_root)):
    if "centroid" in f:
        centriod_idx_paths = os.path.join(index_root, f)
    else:
        idx_paths.append(os.path.join(index_root, f))

# Fail here with a clear message instead of later, when "" is handed to load_index.
if not centriod_idx_paths:
    raise FileNotFoundError(f"no centroid index file found in {index_root!r}")


# some initializations
k = 5
dim = 64
num_shards = 50
num_queries = 100

# process queries (currently from the same distribution)
random_mean = random_floats(1)
random_std = random_floats(1)
queries =  random_normal_vectors(num_queries, dim, random_mean, random_std)

# knn find top centroids
# per-shard k; with k < nprobe the integer division contributes 0, so idx_k == k
idx_k = (k // nprobe) + k
print(idx_k)

# I[q] holds the nprobe centroid ids selected for query q
D, I = query_index_file(centriod_idx_paths, queries, nprobe)
print(I)
print(I.shape)

5
[[10 46 19 ... 49  0 13]
 [ 7  5 20 ... 12  8  0]
 [43 12 31 ... 28 41  7]
 ...
 [10 19 21 ... 23 49 45]
 [36 28 46 ... 13  8  0]
 [10 46 21 ... 49 13  0]]
(100, 20)


# Test outer-loop query

In [24]:
# Query-major search topology: {query index: [shard ids picked by the centroid search]}
stopology = {q_idx: list(I[q_idx]) for q_idx in range(num_queries)}
print(stopology)
print()

# Search query by query and merge the per-shard results into a global top k.
final_D_matrix, final_I_matrix, final_file_idx_matrix = search_outterloop_query(
    stopology, queries, idx_k, k, idx_paths
)

{0: [10, 46, 19, 21, 25, 28, 36, 7, 15, 26, 20, 31, 5, 43, 12, 8, 23, 49, 0, 13], 1: [7, 5, 20, 36, 28, 10, 46, 21, 19, 25, 23, 49, 15, 13, 26, 31, 43, 12, 8, 0], 2: [43, 12, 31, 26, 8, 15, 0, 25, 21, 19, 9, 45, 3, 46, 10, 11, 36, 28, 41, 7], 3: [15, 26, 31, 25, 21, 19, 46, 43, 10, 12, 8, 36, 28, 7, 0, 20, 45, 5, 9, 3], 4: [10, 46, 36, 28, 19, 21, 7, 25, 20, 5, 15, 26, 31, 43, 12, 23, 49, 8, 13, 0], 5: [21, 25, 19, 10, 46, 15, 26, 31, 36, 28, 7, 43, 20, 12, 5, 8, 0, 45, 9, 3], 6: [28, 36, 7, 20, 10, 5, 46, 19, 21, 25, 15, 26, 31, 23, 49, 43, 13, 12, 8, 0], 7: [31, 26, 15, 25, 21, 43, 19, 12, 46, 10, 8, 36, 0, 28, 7, 9, 45, 3, 20, 5], 8: [10, 46, 21, 25, 19, 15, 26, 36, 28, 31, 7, 20, 5, 43, 12, 8, 0, 45, 23, 9], 9: [36, 28, 10, 7, 46, 20, 19, 5, 21, 25, 15, 26, 31, 23, 43, 49, 12, 13, 8, 0], 10: [46, 19, 21, 10, 25, 36, 28, 15, 26, 7, 31, 20, 5, 43, 12, 8, 0, 23, 49, 45], 11: [10, 46, 19, 21, 25, 36, 28, 7, 15, 26, 20, 31, 5, 43, 12, 8, 23, 49, 0, 13], 12: [36, 28, 7, 10, 46, 20, 19, 5

In [25]:
# Distances, ids and source shard ids of the query-major search, separated by blank lines.
print(final_D_matrix, final_I_matrix, final_file_idx_matrix, sep="\n\n")

[[ 6.52106762  6.78923607  6.85729599  6.88746643  6.89422703]
 [10.40779495 11.11463928 11.24617481 11.29669857 11.33357239]
 [ 6.08867359  6.15517426  6.28098011  6.3227849   6.35271835]
 [ 5.80109215  6.12573719  6.50049591  6.59513569  6.7354126 ]
 [ 4.9363637   5.17505026  5.18827772  5.24476576  5.29479122]
 [ 6.37004566  6.44735241  6.70378351  6.74477768  6.79832935]
 [ 8.31094074  8.8658905   9.32531166  9.52775192  9.52835655]
 [ 7.84479332  7.90951014  7.93931913  8.13201523  8.31669044]
 [ 5.7032423   5.79676151  5.90681124  6.00833893  6.01344633]
 [ 5.80945301  6.10432768  6.16604519  6.31072044  6.39087677]
 [ 7.15524387  7.20911884  7.30800867  7.52521181  7.59423018]
 [ 7.27516222  8.07793617  8.087286    8.33469772  8.48258305]
 [ 6.42051935  7.10861301  7.17304468  7.23983669  7.24585247]
 [ 4.92830276  5.29168081  5.32107639  5.4695282   5.49585199]
 [ 8.71671867  9.32459641  9.36468792  9.44102764  9.45532227]
 [ 9.70161819 10.44580078 10.52528381 10.72315025 10.74

# Test outer-loop index

In [26]:
# Shard-major search topology: {shard id: [indices of the queries routed to that shard]}
idx_query_stopology = {}
for q_idx, shard_ids in enumerate(I):
    for shard_id in shard_ids:
        # first sighting of a shard creates its list, later ones append to it
        idx_query_stopology.setdefault(shard_id, []).append(q_idx)

print(idx_query_stopology)

{10: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99], 46: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99], 19: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 

In [27]:
# Search shard by shard, merging each shard's results into the running top k.
final_D_matrix2, final_I_matrix2, final_file_idx_matrix2 = search_outterloop_index(
    idx_query_stopology, queries, idx_k, k, idx_paths
)

# Distances, ids and source shard ids, separated by blank lines.
print(final_D_matrix2, final_I_matrix2, final_file_idx_matrix2, sep="\n\n")

[[ 6.52106762  6.78923607  6.85729599  6.88746643  6.89422703]
 [10.40779495 11.11463928 11.24617481 11.29669857 11.33357239]
 [ 6.08867359  6.15517426  6.28098011  6.3227849   6.35271835]
 [ 5.80109215  6.12573719  6.50049591  6.59513569  6.7354126 ]
 [ 4.9363637   5.17505026  5.18827772  5.24476576  5.29479122]
 [ 6.37004566  6.44735241  6.70378351  6.74477768  6.79832935]
 [ 8.31094074  8.8658905   9.32531166  9.52775192  9.52835655]
 [ 7.84479332  7.90951014  7.93931913  8.13201523  8.31669044]
 [ 5.7032423   5.79676151  5.90681124  6.00833893  6.01344633]
 [ 5.80945301  6.10432768  6.16604519  6.31072044  6.39087677]
 [ 7.15524387  7.20911884  7.30800867  7.52521181  7.59423018]
 [ 7.27516222  8.07793617  8.087286    8.33469772  8.48258305]
 [ 6.42051935  7.10861301  7.17304468  7.23983669  7.24585247]
 [ 4.92830276  5.29168081  5.32107639  5.4695282   5.49585199]
 [ 8.71671867  9.32459641  9.36468792  9.44102764  9.45532227]
 [ 9.70161819 10.44580078 10.52528381 10.72315025 10.74

In [22]:
# compare two results: the query-major and the shard-major search should agree
# on distances, ids and source shard ids
for query_major, index_major in (
    (final_D_matrix, final_D_matrix2),
    (final_I_matrix, final_I_matrix2),
    (final_file_idx_matrix, final_file_idx_matrix2),
):
    print(np.array_equal(query_major, index_major))

True
True
True


# Intelligent batching (approximate)

1. Cluster the queries into groups, then find the k-nearest centroids for each group

In [None]:
def queries_knn_clustering(queries, num_clusters, niter=10, verbose=False):
    '''
    Split a batch of queries into num_clusters sub-batches with k-means.

    A faiss.Kmeans model is trained on the queries themselves, then every query
    is assigned to the single nearest entry of the trained kmeans.index.

    args:
        - queries: (num, dim) query batch; dim is read from queries.shape[1]
        - num_clusters: number of clusters (sub-batches)
        - niter: passed to faiss.Kmeans
        - verbose: passed to faiss.Kmeans

    return:
        - {cluster_id: [q1, q2...]} with one entry per cluster id in
          range(num_clusters); clusters that received no query map to []
    '''
    kmeans = faiss.Kmeans(queries.shape[1], num_clusters, niter=niter, verbose=verbose)
    kmeans.train(queries)

    # nearest[i][0] is the cluster assigned to query i; distances are not needed
    _, nearest = kmeans.index.search(queries, 1)

    clusters = {cluster_id: [] for cluster_id in range(num_clusters)}
    for q_idx, row in enumerate(nearest):
        clusters[row[0]].append(q_idx)
    return clusters

def reverse_stopology(stopology):
    '''
    Invert a search topology: {key: [val1, val2...]} becomes {val: [key1, key2...]}.

    Values appear in the result in the order they are first encountered, and
    the keys collected for a value keep the iteration order of the input dict.
    Keys whose list is empty do not show up in the result.
    '''
    reverse_dict = {}
    for key, val_list in stopology.items():
        for val in val_list:
            reverse_dict.setdefault(val, []).append(key)
    return reverse_dict

def stopology_overlap(stopology):
    '''
    This function checks the overlap of the values in the stopology dict.

    NOTE(review): unfinished stub -- the body only creates an empty dict and
    then falls off the end, so every call currently returns None and the
    stopology argument is never read.
    '''
    # overlap dict: {idx1: [idx2, idx3...]}
    overlap_dict = {}

In [None]:
# Smoke test for queries_knn_clustering: split the current query batch into 5 groups
# and show which query indices landed in each group.
num_clusters = 5
clusters = queries_knn_clustering(queries, num_clusters=num_clusters)
print(clusters)

In [None]:
# Number of queries routed to each shard, ordered from the least to the most loaded shard
# (shards with the same count keep their order from idx_query_stopology).
idx_query_stopology_len = dict(
    sorted(
        ((shard_id, len(q_idxs)) for shard_id, q_idxs in idx_query_stopology.items()),
        key=lambda pair: pair[1],
    )
)
idx_query_stopology_len

In [83]:
def random_queries_mix_distribs(num_queries, dim, mixtures_ratio=1, low=0, high=1):
    '''
    This function generates random queries drawn from a mix of normal distributions.

    Queries are produced in consecutive groups. Every group gets its own mean
    and std, each drawn with random_floats(1, low, high), and its rows are
    sampled with random_normal_vectors.

    Args:
        - num_queries: number of queries to generate
        - dim: dimensionality of the queries
        - mixtures_ratio [0,1]:
            - 0 means only one distribution.
            - 1 means every query is drawn from a different distribution.
            - 0.1 means each 10% of the query batch is drawn from one distribution
              (group size int(num_queries * mixtures_ratio), at least 1).
        - low: lower bound passed to random_floats (both mean and std)
        - high: upper bound passed to random_floats (both mean and std)

    Returns:
        - (num_queries, dim) array built from the arrays returned by
          random_normal_vectors, so every branch yields the helper's dtype.

    Raises:
        - ValueError: if mixtures_ratio is outside [0, 1]
    '''
    if not 0 <= mixtures_ratio <= 1:
        raise ValueError(f"mixtures_ratio must be in [0, 1], got {mixtures_ratio}")

    if mixtures_ratio == 0:
        # a single distribution for the whole batch
        random_mean = random_floats(1, low, high)[0]
        random_std = random_floats(1, low, high)[0]
        return random_normal_vectors(num_queries, dim, random_mean, random_std)

    if mixtures_ratio == 1:
        sample_size = 1
    else:
        # Clamp to 1: for small ratios int(...) is 0, which used to end in a
        # modulo-by-zero when deciding where a new group starts.
        sample_size = max(1, int(num_queries * mixtures_ratio))

    groups = []
    for start in range(0, num_queries, sample_size):
        # mean first, then std, for every group
        random_mean = random_floats(1, low, high)[0]
        random_std = random_floats(1, low, high)[0]
        # the last group may be shorter than sample_size
        group_size = min(sample_size, num_queries - start)
        groups.append(random_normal_vectors(group_size, dim, random_mean, random_std))

    if not groups:
        # num_queries == 0
        # NOTE(review): float32 matches the dtype seen in the recorded outputs of
        # random_normal_vectors -- confirm against the helper.
        return np.zeros((0, dim), dtype=np.float32)
    return np.concatenate(groups, axis=0)

In [85]:
# Positional arguments spelled out as keywords; the binding is unchanged.
# NOTE(review): this pins mixtures_ratio=0 with low == high == 1 -- confirm that
# this is the intended configuration.
random_queries_mix_distribs(10, 5, mixtures_ratio=0, low=1, high=1)

array([[ 1.4870925e+00,  6.6816896e-01, -4.2996353e-01,  1.2595263e+00,
         1.5886284e+00],
       [ 1.0175064e+00,  1.5429906e+00,  1.6057872e+00,  1.6061089e+00,
         1.4111378e+00],
       [ 2.3599353e-03, -1.2380664e+00,  2.2358882e+00,  4.4222981e-01,
         8.7424690e-01],
       [ 1.4135383e+00, -6.5668362e-01,  1.1655186e+00,  2.6856822e-01,
        -4.4888431e-01],
       [-1.0438749e+00,  1.9262393e+00,  2.7522919e+00,  7.5073594e-01,
         1.9149030e+00],
       [ 9.5477885e-01,  2.6365659e+00,  3.0313277e-01,  2.3647342e+00,
        -3.4463525e-01],
       [ 2.2486997e+00,  1.3892709e+00,  1.2398897e+00,  1.9064709e+00,
         5.6471306e-01],
       [ 1.4362975e+00,  5.9811503e-01, -1.1513985e+00, -1.7313069e-01,
        -4.1947167e-02],
       [-1.8666120e-01,  2.8301976e+00,  1.9496770e+00,  8.2169998e-01,
         1.1585656e-01],
       [ 1.8699143e+00,  1.8309375e+00,  2.0125000e+00,  1.2244114e+00,
         1.2772148e+00]], dtype=float32)

In [71]:
# Positional arguments spelled out as keywords; the binding is unchanged.
# NOTE(review): low=1 is greater than high=0 here, and this cell's execution count
# is lower than that of the cell defining the function, so the argument order may
# be stale -- TODO confirm.
random_queries_mix_distribs(10, 5, mixtures_ratio=0, low=1, high=0)

array([[ 0.25678778,  0.33809626,  0.20478295,  1.9232888 ,  0.2280871 ],
       [ 1.5397611 , -0.4476633 ,  0.1596234 ,  1.5886568 ,  0.13145635],
       [ 1.1333157 ,  0.89155096,  0.618855  ,  0.34158555,  0.84976345],
       [-0.17071183,  1.8930807 ,  0.12922935,  1.7087269 ,  2.4321628 ],
       [ 1.6071975 ,  1.4515423 ,  1.6974658 ,  1.4345169 ,  0.21072517],
       [-0.4298537 , -0.81705785,  0.49031326,  1.0953708 ,  0.8020883 ],
       [-0.61599505,  0.28845945, -0.19621788,  0.2473756 ,  1.2683831 ],
       [ 0.38506258, -0.3766805 ,  1.35082   ,  0.41475973,  0.32497922],
       [ 0.48606375,  0.72981566,  1.4222623 ,  2.0935383 ,  0.7149932 ],
       [ 0.48011962,  0.9305447 , -0.681897  ,  0.21601701,  2.7042222 ]],
      dtype=float32)