In [96]:
import pandas as pd
import numpy as np
from sklearn.preprocessing import normalize
import os
import time

In [201]:
# --- Experiment configuration ---------------------------------------------
config = 'multihead/'
root_dir = f"/mnt/disks/retrieval/corrected_fwd_pass/{config}"
index = 'exactl2'        # index tag used in the k-NN input and output file names
dataset = 'imagenet1k'
use_cascade = 0          # truthy -> cascade mode (see naive_rerank / output dir choice)

# Embedding dim at which the precomputed neighbors were retrieved (part of the NN file name).
retrieval_dim = 16
max_retrieval_dim = 2048
max_rerank_dim = 2048

# Dim(s) at which neighbors are re-ordered by L2 distance, and the shortlist
# length used at each of those dims (paired by position).
rerank_dim = [2048]
shortlist = [100]
# Funnel alternative: rerank_dim = [16, 32, 64, 128, 2048]; shortlist = [200, 100, 50, 25, 10]

In [204]:
# Load database (train) and query (val) embeddings from .npy files, keep the
# first max_rerank_dim columns, then L2-normalize every row.
# (The *_csv names are historical; both files are .npy arrays.)
db_csv = f"{dataset}_train_nesting1_sh0_ff2048-X.npy"
query_csv = f"{dataset}_val_nesting1_sh0_ff2048-X.npy"

start = time.time()
db_rerank = np.load(f"{root_dir}{db_csv}")[:, :max_rerank_dim]
end = time.time() - start
print("Load database vectors (%d x %d), time= %f" % (*db_rerank.shape, end))

start = time.time()
queries = np.load(f"{root_dir}{query_csv}")[:, :max_rerank_dim]
end = time.time() - start
print("Load query vectors (%d x %d), time= %f" % (*queries.shape, end))

# Row-wise unit-norm scaling; queries first, then the (much larger) database.
start = time.time()
queries = normalize(queries, axis=1)
db_rerank = normalize(db_rerank, axis=1)
end = time.time() - start
print("Normalization time= %f" % end)

Load database vectors (1281167 x 2048), time= 60.301246
Load query vectors (50000 x 2048), time= 0.987897
Normalization time= 188.171419


## Load precomputed neighbors (modify the cell below to avoid expensive file loads for the 4M dataset)

In [205]:
# Load the precomputed k-NN array: one row per query, entries are row indices
# into the database embeddings.
NN_file = f"{root_dir}neighbors/{index}_{retrieval_dim}dim-2048-NN_{dataset}.csv"
start = time.time()
neighbors = pd.read_csv(NN_file, header=None).to_numpy()
end = time.time() - start
print("Loaded %s : (%d x %d), time= %f" % (NN_file.rsplit("/", 1)[-1], *neighbors.shape, end))

Loaded exactl2_16dim-2048-NN_imagenet1k.csv : (50000 x 2048), time= 12.993581


In [206]:
# Sanity-check the shapes of the arrays the reranking step reads.
for label, array in (
    ("\nDB for reranking: ", db_rerank),
    ("Queries for reranking: ", queries),
    ("k-NN array: ", neighbors),
):
    print(label, array.shape)


DB for reranking:  (1281167, 2048)
Queries for reranking:  (50000, 2048)
k-NN array:  (50000, 2048)


# Naive Routing/Cascading Strategy

In [207]:
def _l2_normalize_rows(mat):
    """Return `mat` (2-D) with every row scaled to unit L2 norm.

    Zero-norm rows are left as zeros instead of dividing by zero, like
    sklearn's ``normalize(..., axis=1)``.
    """
    norms = np.linalg.norm(mat, axis=1, keepdims=True)
    norms[norms == 0] = 1.0
    return mat / norms


def naive_rerank(use_cascade, rerank_dim, shortlist, neighbors, db_vectors=None, query_vectors=None):
    """Rerank retrieved neighbors stage by stage with truncated embeddings.

    Stage ``i`` keeps the first ``shortlist[i]`` neighbors of every query and
    reorders them by L2 distance between the query and each candidate, both
    truncated to the first ``rerank_dim[i]`` dims (candidates are re-normalized
    after truncation). Each stage's output feeds the next, so a funnel/cascade
    is simply several stages.

    Args:
        use_cascade: if truthy, assert ``len(rerank_dim) == len(shortlist)``.
        rerank_dim: list of embedding dims, one per stage.
        shortlist: shortlist lengths; ``shortlist[i]`` is used at stage ``i``
            (IndexError if shorter than ``rerank_dim``).
        neighbors: 2-D integer array, one row of database row indices per
            query. It is NOT modified.
        db_vectors: database embeddings indexed by the neighbor ids; defaults
            to the notebook global ``db_rerank``.
        query_vectors: query embeddings, row ``j`` for neighbor row ``j``;
            defaults to the notebook global ``queries``.

    Returns:
        A new array with the last stage's shortlist length as its column count,
        holding reranked neighbor indices (``neighbors`` itself if
        ``rerank_dim`` is empty).
    """
    if db_vectors is None:
        db_vectors = db_rerank
    if query_vectors is None:
        query_vectors = queries

    # ensure these match for naive routing strategy
    if use_cascade:
        assert len(rerank_dim) == len(shortlist)

    for i, dim in enumerate(rerank_dim):
        k = shortlist[i]
        # Copy: basic slicing returns a view, so writing reranked rows into it
        # would silently reorder the caller's array and contaminate later calls.
        stage_neighbors = neighbors[:, :k].copy()

        for j in range(len(stage_neighbors)):
            query_vector = query_vectors[j][:dim]
            nn_indices = stage_neighbors[j]

            # No .squeeze(): it would collapse (1, dim) / (k, 1) to 1-D and
            # break row-wise normalization when k == 1 or dim == 1.
            candidates = _l2_normalize_rows(db_vectors[nn_indices, :dim])

            # The truncated query is not re-normalized; all candidates have unit
            # norm, so scaling the query does not change the distance ordering.
            distances = np.linalg.norm(candidates - query_vector, axis=1)
            stage_neighbors[j] = nn_indices[np.argsort(distances)]

        neighbors = stage_neighbors
    return neighbors

## Rerank at each dim in `rerank_dim` for a fixed shortlist length k. The retrieval dim is also fixed, because the neighbors are loaded from the precomputed NN CSV file.

In [211]:
# Shortlist configurations to sweep; each entry is a single-stage shortlist.
shortlist_list = [
    [200],
]

In [212]:
# Single-stage rerank sweep: for every shortlist config and every rerank dim,
# rerank the retrieved neighbors and optionally write them to CSV.
# Saving is an explicit switch so the log never claims a file was written
# when it was not.
save_reranked = False  # set True to actually write the CSV files

# The output directory depends only on use_cascade, so resolve it once.
if use_cascade:
    nn_dir = root_dir + "neighbors/cascade_naive_policy/"
else:
    nn_dir = root_dir + "neighbors/reranked/"

# Distinct loop name so the config-cell `shortlist` is not clobbered.
for sweep_shortlist in shortlist_list:
    for dim in rerank_dim:
        start = time.time()
        neighbors_reranked = naive_rerank(use_cascade, [dim], sweep_shortlist, neighbors)
        end = time.time() - start
        print("\nRetrieve @%d + rerank@%d, time = %f" % (retrieval_dim, dim, end))

        neighbors_df = pd.DataFrame(neighbors_reranked)
        print(neighbors_df.shape)

        filename = (str(retrieval_dim) + "dim-reranked" + str(dim) + "_" + str(sweep_shortlist[0])
                    + "shortlist_" + dataset + "_" + index + ".csv")

        if save_reranked:
            os.makedirs(nn_dir, exist_ok=True)
            print("Saving config: ", filename)
            neighbors_df.to_csv(nn_dir + filename, header=False, index=False)
        else:
            print("Not saving (save_reranked=False): ", filename)


Retrieve @16 + rerank@2048, time = 83.435824
(50000, 200)
Saving config:  16dim-reranked2048_200shortlist_imagenet1k_exactl2.csv


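## Funnel Retrieval (increase the rerank dim and shrink the shortlist length in sync at each stage)

Note: the outputs recorded below come from an earlier run (8-dim retrieval on imagenet4m) and do not match the current configuration.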

In [190]:
# Funnel shortlist configs: one shortlist length per cascade stage, shrinking
# stage by stage; each list is paired position-wise with the rerank dims.
shortlist_set = [
    [800, 400, 200, 50, 10],
    [400, 200, 50, 25, 10],
    [200, 100, 50, 25, 10],
]

In [191]:
# Funnel/cascade rerank for each shortlist config in shortlist_set: stage i
# reranks the top shortlist[i] neighbors at cascade_rerank_dim[i] dims.
# The stage dims are set here (the funnel dims used in the recorded runs)
# instead of reusing the global single-stage `rerank_dim`; naive_rerank asserts
# equal lengths in cascade mode, so a 1-element dim list would fail against
# these 5-stage shortlists on a fresh kernel.
cascade_rerank_dim = [16, 32, 64, 128, 2048]
nn_dir = root_dir + "neighbors/cascade_naive_policy/"
os.makedirs(nn_dir, exist_ok=True)  # to_csv does not create missing directories

# Distinct loop name so the config-cell `shortlist` is not clobbered.
for cascade_shortlist in shortlist_set:
    start = time.time()
    NN_cascade = naive_rerank(1, cascade_rerank_dim, cascade_shortlist, neighbors)
    end = time.time() - start
    print("\nRetrieve @%d + cascade naive policy @%s with shortlist %s, time = %f"
          % (retrieval_dim, cascade_rerank_dim, cascade_shortlist, end))

    neighbors_df = pd.DataFrame(NN_cascade)
    print(neighbors_df.shape)

    filename = (str(retrieval_dim) + "dim-cascade" + str(cascade_rerank_dim) + "_" + str(cascade_shortlist)
                + "shortlist_" + dataset + "_" + index + ".csv")

    print("Saving config: ", filename)
    neighbors_df.to_csv(nn_dir + filename, header=False, index=False)


Retrieve @8 + cascade naive policy @[16, 32, 64, 128, 2048] with shortlist [800, 400, 200, 50, 10], time = 245.806595
(210100, 10)
Saving config:  8dim-cascade[16, 32, 64, 128, 2048]_[800, 400, 200, 50, 10]shortlist_imagenet4m_exactl2.csv

Retrieve @8 + cascade naive policy @[16, 32, 64, 128, 2048] with shortlist [400, 200, 50, 25, 10], time = 176.763507
(210100, 10)
Saving config:  8dim-cascade[16, 32, 64, 128, 2048]_[400, 200, 50, 25, 10]shortlist_imagenet4m_exactl2.csv

Retrieve @8 + cascade naive policy @[16, 32, 64, 128, 2048] with shortlist [200, 100, 50, 25, 10], time = 152.884837
(210100, 10)
Saving config:  8dim-cascade[16, 32, 64, 128, 2048]_[200, 100, 50, 25, 10]shortlist_imagenet4m_exactl2.csv


In [157]:
# Re-save the most recent result using the nn_dir / filename / neighbors_df
# globals left over from an earlier cell.
target_path = nn_dir + filename
print("Saving config: ", filename)
pd.DataFrame(neighbors_df).to_csv(target_path, header=None, index=None)

Saving config:  64dim-cascade[128, 256, 512, 1024, 2048]_[200, 100, 50, 25, 10]shortlist_imagenet4m_exactl2.csv
