In [2]:
import scann
import h5py
import numpy as np
import time
import puffinn

In [3]:
with h5py.File('lsh/.glove-25-angular.hdf5') as hf:
    dataset = hf['/train'][:]
    dataset = dataset / np.linalg.norm(dataset, axis=1)[:, np.newaxis]
    queries  = hf['/test'][:]
    #queries = queries / np.linalg.norm(queries, axis=1)[:, np.newaxis]
    ground = hf['/neighbors'][:, :10]

In [4]:
def compute_recall(neighbors, true_neighbors):
    total = 0
    for gt_row, row in zip(true_neighbors, neighbors):
        total += np.intersect1d(gt_row, row).shape[0]
    return total / true_neighbors.size


In [5]:
searcher = scann.scann_ops_pybind.builder(dataset, 10, "dot_product").tree(
    num_leaves=2000, num_leaves_to_search=100, training_sample_size=250000).score_ah(
    2, anisotropic_quantization_threshold=0.2).reorder(100).build()

2023-05-04 15:36:38.073539: I scann/partitioning/partitioner_factory_base.cc:59] Size of sampled dataset for training partition: 249797
2023-05-04 15:36:39.488436: I ./scann/partitioning/kmeans_tree_partitioner_utils.h:88] PartitionerFactory ran in 1.414834134s.


In [6]:
start = time.time()
neighbors, distances = searcher.search_batched(queries, leaves_to_search=100)
end = time.time()

print("Recall:", compute_recall(neighbors, ground))
print("Time:", end - start)
print("Time per query:", 1000*((end - start) / queries.shape[0]), "ms")

Recall: 0.80583
Time: 0.6993043422698975
Time per query: 0.06993043422698975 ms


In [10]:
query_idx = 1

In [11]:
%timeit searcher.search(queries[query_idx,:], leaves_to_search=1000)

328 µs ± 6 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [14]:
r = 0
leaves_to_search = 100
while r < 0.5:
    leaves_to_search += 100
    r = compute_recall(searcher.search(queries[query_idx,:], leaves_to_search=leaves_to_search)[0], ground[query_idx,:])
r, leaves_to_search

KeyboardInterrupt: 

In [40]:
brute = scann.scann_ops_pybind.builder(dataset, 10, "dot_product").score_brute_force().build()

In [41]:
start = time.time()
neighbors, distances = brute.search_batched(queries)
end = time.time()

print("Recall:", compute_recall(neighbors, ground))
print("Time:", end - start)

Recall: 1.0
Time: 4.924302577972412


## PUFFINN

In [18]:
# Construct the index using the cosine similarity measure,
# the default hash functions and 4 GB of memory.
index = puffinn.Index('angular', dataset.shape[1], 4*1024**3)
for v in dataset:
    index.insert(list(v))
index.rebuild()

Starting index_build
Building sketches
Number of tables: 380
Done index_build in 139046 ms


In [25]:
start = time.time()
r = 0 
for i, q in enumerate(queries):
    ans = index.search(list(q), 10, 0.8)
    r += compute_recall(ans, ground[i])
end = time.time()

print("Recall:", r / queries.shape[0])
print("Time:", end - start)

Recall: 0.5609600000000018
Time: 18.556936502456665
