In [None]:
# Copyright 2022 Google LLC

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# ScaNN Demo with GloVe Dataset

In [1]:
import numpy as np
import h5py
import os
import requests
import tempfile
import time

import scann

2023-10-16 15:23:01.598729: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2023-10-16 15:23:01.633709: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2023-10-16 15:23:01.635014: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


### Download dataset

In [3]:
# Download the GloVe-100 angular benchmark file into a throwaway directory
# and open it read-only with h5py.
with tempfile.TemporaryDirectory() as tmp:
    loc = os.path.join(tmp, "glove.hdf5")
    # Stream the body to disk in chunks instead of buffering the whole
    # payload in memory via response.content, and bound how long a stalled
    # connection may hang.
    with requests.get("http://ann-benchmarks.com/glove-100-angular.hdf5",
                      stream=True, timeout=60) as response:
        # Fail loudly on 4xx/5xx; otherwise an HTML error page would be saved
        # as glove.hdf5 and only surface later as an opaque HDF5 read error.
        response.raise_for_status()
        with open(loc, 'wb') as f:
            for chunk in response.iter_content(chunk_size=1 << 20):
                f.write(chunk)

    # NOTE(review): the handle is opened before the temporary directory is
    # removed and is used by later cells afterwards; this presumably relies on
    # the OS keeping an open-but-deleted file readable -- confirm on
    # non-POSIX platforms.
    glove_h5py = h5py.File(loc, "r")

In [4]:
# Show the names of the top-level entries stored in the downloaded HDF5 file.
list(glove_h5py.keys())

['distances', 'neighbors', 'test', 'train']

In [5]:
# 'train' holds the vectors to index; 'test' holds the query vectors.
dataset, queries = glove_h5py['train'], glove_h5py['test']
# Print both shapes, one per line.
print(dataset.shape, queries.shape, sep="\n")

(1183514, 100)
(10000, 100)


In [10]:
# Inspect the Python type of the object obtained from glove_h5py['train'].
type(dataset)

h5py._hl.dataset.Dataset

In [11]:
# Inspect the Python type of the object obtained from glove_h5py['test'].
type(queries)

h5py._hl.dataset.Dataset

### Create ScaNN searcher

In [12]:
# Scale every row to unit L2 norm; keepdims=True keeps the norms as a column
# so the division broadcasts across each row.
normalized_dataset = dataset / np.linalg.norm(dataset, axis=1, keepdims=True)

# Configure ScaNN as a tree / asymmetric-hashing hybrid with reordering, using
# anisotropic quantization as described in the paper (see README).
# A TensorFlow-compatible searcher is available through scann.scann_ops
# instead of scann.scann_ops_pybind (see README).
searcher = (
    scann.scann_ops_pybind.builder(normalized_dataset, 10, "dot_product")
    .tree(num_leaves=2000, num_leaves_to_search=100, training_sample_size=250000)
    .score_ah(2, anisotropic_quantization_threshold=0.2)
    .reorder(100)
    .build()
)

2023-10-16 15:32:47.334315: I scann/partitioning/partitioner_factory_base.cc:59] Size of sampled dataset for training partition: 249797
2023-10-16 15:32:52.959762: I ./scann/partitioning/kmeans_tree_partitioner_utils.h:84] PartitionerFactory ran in 5.625388093s.


In [13]:
# Inspect the type produced by dividing the dataset by its row norms.
type(normalized_dataset)

numpy.ndarray

In [14]:
# Shape of the normalized array, for comparison with dataset.shape.
normalized_dataset.shape

(1183514, 100)

In [19]:
# Number of components in the first normalized vector.
len(normalized_dataset[0])

100

In [20]:
# Display the first normalized vector in full.
normalized_dataset[0]

array([-0.02701984,  0.11539876,  0.02164138, -0.05349847,  0.0081553 ,
       -0.13311078,  0.00997753, -0.12772731,  0.04484392, -0.14000343,
        0.00365088, -0.00347016,  0.19274133, -0.0091836 ,  0.17964269,
        0.16808899, -0.04258849,  0.07681806,  0.16111052,  0.16021168,
        0.06209341,  0.09982534, -0.08135276,  0.0545022 , -0.12762241,
        0.2999767 , -0.02182544,  0.04700636, -0.00892968, -0.07953603,
        0.07486065,  0.08699372,  0.16990334,  0.03116115, -0.05877941,
       -0.12503797, -0.00860472,  0.13129166,  0.02388226,  0.11466681,
        0.16952427, -0.01274627,  0.05322667,  0.07371149, -0.09519051,
        0.00873418, -0.08447365, -0.10203069,  0.11073054,  0.06100146,
        0.16273652, -0.04964088,  0.09163094,  0.01329725, -0.06053416,
       -0.04960034,  0.12522155, -0.02717719, -0.07755716, -0.10515159,
        0.04178979,  0.14842671,  0.11977372, -0.18136406, -0.01711503,
        0.00191046, -0.03167613,  0.11943993,  0.04487968, -0.13

In [21]:
# Sanity check: after normalization the first vector should have norm 1.
np.linalg.norm(normalized_dataset[0])

1.0

In [6]:
def compute_recall(neighbors, true_neighbors):
    """Return the fraction of ground-truth neighbors that were retrieved.

    Rows of ``true_neighbors`` and ``neighbors`` are paired in order. For each
    pair, the number of distinct ids present in both rows is counted, and the
    grand total is divided by the number of entries in ``true_neighbors``.

    Args:
        neighbors: retrieved neighbor ids, one row per query.
        true_neighbors: ground-truth neighbor ids, one row per query; must
            expose a ``.size`` attribute (e.g. a NumPy array).

    Returns:
        Total matched ids divided by ``true_neighbors.size``.
    """
    hits = sum(
        np.intersect1d(expected, found).shape[0]
        for expected, found in zip(true_neighbors, neighbors)
    )
    return hits / true_neighbors.size

### ScaNN interface features

In [7]:
# Batched search with the build-time settings: this searches the top 100 of
# the 2000 leaves and computes exact dot products for the top 100 candidates
# from asymmetric hashing to produce the final top 10.
# perf_counter() is used for timing because it is a monotonic,
# high-resolution clock meant for measuring intervals; time.time() follows the
# wall clock and can jump if the system time is adjusted.
start = time.perf_counter()
neighbors, distances = searcher.search_batched(queries)
end = time.perf_counter()

# The ground truth provides the top 100 neighbors per query; keep only the
# top 10 to match the number of neighbors returned.
print("Recall:", compute_recall(neighbors, glove_h5py['neighbors'][:, :10]))
print("Time:", end - start)

Recall: 0.8999
Time: 1.3812487125396729


In [8]:
# Searching more leaves (150 here instead of the build-time 100) increases
# recall at the cost of speed.
# perf_counter() is a monotonic, high-resolution interval clock, unlike the
# adjustable wall clock behind time.time().
start = time.perf_counter()
neighbors, distances = searcher.search_batched(queries, leaves_to_search=150)
end = time.perf_counter()

# Compare against the first 10 ground-truth neighbors of each query.
print("Recall:", compute_recall(neighbors, glove_h5py['neighbors'][:, :10]))
print("Time:", end - start)

Recall: 0.92327
Time: 1.8380558490753174


In [9]:
# Increasing reordering (the exact scoring of the top asymmetric-hashing
# candidates; 250 here instead of the build-time 100) has a similar effect.
# perf_counter() is a monotonic, high-resolution interval clock, unlike the
# adjustable wall clock behind time.time().
start = time.perf_counter()
neighbors, distances = searcher.search_batched(queries, leaves_to_search=150, pre_reorder_num_neighbors=250)
end = time.perf_counter()

# Compare against the first 10 ground-truth neighbors of each query.
print("Recall:", compute_recall(neighbors, glove_h5py['neighbors'][:, :10]))
print("Time:", end - start)

Recall: 0.93098
Time: 2.2772152423858643


In [10]:
# The number of neighbors returned can be overridden per call.
# Without an override, the count configured when the searcher was built (10)
# is used.
neighbors, distances = searcher.search_batched(queries)
print(f"{neighbors.shape} {distances.shape}")

# With final_num_neighbors=20 each query gets 20 results instead.
neighbors, distances = searcher.search_batched(queries, final_num_neighbors=20)
print(f"{neighbors.shape} {distances.shape}")

(10000, 10) (10000, 10)
(10000, 20) (10000, 20)


In [11]:
# Everything so far used batched search; the single-query call takes the same
# optional arguments.
# A single lookup takes well under a millisecond here, so the timer must be
# high-resolution and monotonic: use perf_counter() rather than the wall-clock
# time.time().
start = time.perf_counter()
neighbors, distances = searcher.search(queries[0], final_num_neighbors=5)
end = time.perf_counter()

print(neighbors)
print(distances)
# Convert seconds to milliseconds for display.
print("Latency (ms):", 1000*(end - start))

[ 97478 262700 846101 671078 232287]
[2.5518737 2.542952  2.539792  2.5383418 2.519638 ]
Latency (ms): 0.7724761962890625
