<a href="https://colab.research.google.com/github/arnaudmkonan/Transformers-text-classification/blob/master/scann_scalable_search_example_ipynb_txt.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Copyright 2022 Google LLC

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# ScaNN Demo with GloVe Dataset

In [39]:
# Install the ScaNN package quietly; it is imported as `scann` in the imports cell.
!pip -q install scann

In [2]:
import numpy as np
import h5py
import os
import requests
import tempfile
import time

import scann

### Download dataset

In [4]:
!wget http://ann-benchmarks.com/nytimes-256-angular.hdf5

--2022-10-19 05:31:23--  http://ann-benchmarks.com/nytimes-256-angular.hdf5
Resolving ann-benchmarks.com (ann-benchmarks.com)... 52.216.44.253
Connecting to ann-benchmarks.com (ann-benchmarks.com)|52.216.44.253|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 315208288 (301M) [binary/octet-stream]
Saving to: ‘nytimes-256-angular.hdf5’


2022-10-19 05:31:30 (43.6 MB/s) - ‘nytimes-256-angular.hdf5’ saved [315208288/315208288]



In [20]:
# Download the GloVe-100 angular benchmark into a scratch directory and open it
# read-only as `glove_h5py`, which the later cells query.
with tempfile.TemporaryDirectory() as tmp:
    response = requests.get("http://ann-benchmarks.com/glove-100-angular.hdf5")
    # Fail fast on an HTTP error rather than saving an error page to disk and
    # hitting an opaque "not an HDF5 file" failure when h5py opens it.
    response.raise_for_status()
    loc = os.path.join(tmp, "glove.hdf5")
    with open(loc, 'wb') as f:
        f.write(response.content)

    # Opened while the temporary directory still exists.
    # NOTE(review): the directory is removed when this `with` block exits, so
    # later reads rely on the OS keeping the already-open file readable --
    # confirm on the target platform.
    glove_h5py = h5py.File(loc, "r")

In [21]:
# Open the NYTimes-256 benchmark that the wget cell saved to the working
# directory. The file is only read here: reopening it with mode 'wb' would
# truncate the download and replace it with the GloVe bytes in `response`.
loc = "nytimes-256-angular.hdf5"
nyt_256_angular_h5py = h5py.File(loc, "r")
list(nyt_256_angular_h5py.keys())

['distances', 'neighbors', 'test', 'train']

In [22]:
# Show the top-level keys stored in the GloVe HDF5 file.
list(glove_h5py.keys())

['distances', 'neighbors', 'test', 'train']

In [28]:
# Inspect the NYTimes 'train' entry; its repr reports the shape and dtype.
nyt_256_angular_h5py['train']


<HDF5 dataset "train": shape (290000, 256), type "<f4">

In [29]:
# Uncomment to run the remaining cells on the NYTimes-256 data instead of GloVe:
# glove_h5py = nyt_256_angular_h5py

In [25]:
# 'train' holds the vectors to index; 'test' holds the query vectors.
dataset, queries = glove_h5py['train'], glove_h5py['test']
print(dataset.shape, queries.shape, sep="\n")

(1183514, 100)
(10000, 100)


In [26]:
# Show which HDF5 file object backs `dataset`.
dataset.file

<HDF5 file "glove.hdf5" (mode r)>

### Create ScaNN searcher

In [27]:
# Scale every row to unit L2 norm before indexing with "dot_product".
row_norms = np.linalg.norm(dataset, axis=1, keepdims=True)
normalized_dataset = dataset / row_norms

# configure ScaNN as a tree - asymmetric hash hybrid with reordering
# anisotropic quantization as described in the paper; see README
#
# use scann.scann_ops.build() to instead create a TensorFlow-compatible searcher
searcher = (
    scann.scann_ops_pybind.builder(normalized_dataset, 10, "dot_product")
    .tree(num_leaves=2000, num_leaves_to_search=100, training_sample_size=250000)
    .score_ah(2, anisotropic_quantization_threshold=0.2)
    .reorder(100)
    .build()
)

In [30]:
def compute_recall(neighbors, true_neighbors):
    """Return the fraction of ground-truth neighbors that were retrieved.

    Rows of ``neighbors`` and ``true_neighbors`` are paired in order. For each
    pair, the distinct ids present in both rows count as hits. The hit total
    is divided by the number of entries in ``true_neighbors``.
    """
    hits = sum(
        np.intersect1d(expected, found).shape[0]
        for expected, found in zip(true_neighbors, neighbors)
    )
    return hits / true_neighbors.size

### ScaNN interface features

In [38]:
# Ground-truth neighbor ids, keeping only the first 10 columns for each query.
glove_h5py['neighbors'][:, :10]

array([[  97478,  262700,  846101, ..., 1133489,  723915,  660281],
       [ 875925,  903728,  144313, ...,  675600,  891287,  712921],
       [1046944,  809599,  531832, ...,  988527,  377259,  625676],
       ...,
       [1108312,  330498,  945288, ...,  350756, 1180096,  196396],
       [ 214774, 1024728, 1114909, ...,  793539,  958245,  699403],
       [ 423309,  484674, 1139759, ...,  206789,  804109,  974574]],
      dtype=int32)

In [31]:
# this will search the top 100 of the 2000 leaves, and compute
# the exact dot products of the top 100 candidates from asymmetric
# hashing to get the final top 10 candidates.
# perf_counter() is a monotonic, high-resolution clock, so the measured
# interval cannot be skewed by system clock adjustments the way time.time() can.
start = time.perf_counter()
neighbors, distances = searcher.search_batched(queries)
end = time.perf_counter()

# we are given top 100 neighbors in the ground truth, so select top 10
print("Recall:", compute_recall(neighbors, glove_h5py['neighbors'][:, :10]))
print("Time:", end - start)

Recall: 0.89965
Time: 2.1338512897491455


In [37]:
# increasing the leaves to search increases recall at the cost of speed
# Timed with the monotonic perf_counter() clock rather than wall-clock time.time().
start = time.perf_counter()
neighbors, distances = searcher.search_batched(queries, leaves_to_search=1000)
end = time.perf_counter()

# Recall is measured against the first 10 ground-truth neighbors per query.
print("Recall:", compute_recall(neighbors, glove_h5py['neighbors'][:, :10]))
print("Time:", end - start)

Recall: 0.98051
Time: 12.514467716217041


In [33]:
# increasing reordering (the exact scoring of top AH candidates) has a similar effect.
# Timed with the monotonic perf_counter() clock rather than wall-clock time.time().
start = time.perf_counter()
neighbors, distances = searcher.search_batched(queries, leaves_to_search=150, pre_reorder_num_neighbors=250)
end = time.perf_counter()

# Recall is measured against the first 10 ground-truth neighbors per query.
print("Recall:", compute_recall(neighbors, glove_h5py['neighbors'][:, :10]))
print("Time:", end - start)

Recall: 0.93143
Time: 3.405034303665161


In [34]:
# we can also dynamically configure the number of neighbors returned
# currently returns 10, as configured when the searcher was built
neighbors, distances = searcher.search_batched(queries)
print(neighbors.shape, distances.shape)

# now returns 20 per query; final_num_neighbors overrides the built-in default
neighbors, distances = searcher.search_batched(queries, final_num_neighbors=20)
print(neighbors.shape, distances.shape)

(10000, 10) (10000, 10)
(10000, 20) (10000, 20)


In [35]:
# we have been exclusively calling batch search so far; the single-query call has the same API
# A single query takes only milliseconds, so it is timed with the
# high-resolution, monotonic perf_counter() instead of wall-clock time.time().
start = time.perf_counter()
neighbors, distances = searcher.search(queries[0], final_num_neighbors=5)
end = time.perf_counter()

print(neighbors)
print(distances)
print("Latency (ms):", 1000*(end - start))

[ 97478 846101 671078 727732 544474]
[2.5518737 2.539792  2.5383418 2.5097368 2.4656374]
Latency (ms): 4.87065315246582


In [36]:
# Display the raw vector of the first query (the one searched in the single-query cell).
queries[0]

array([ 0.39553  ,  0.23048  ,  0.82722  ,  0.10453  , -0.69281  ,
       -0.83357  , -0.49049  , -0.036362 , -0.48396  , -0.44315  ,
       -0.37407  , -0.13825  ,  0.3158   ,  0.16467  ,  0.1318   ,
       -0.34739  ,  0.30084  ,  0.26194  ,  0.60956  , -0.21171  ,
        0.26935  , -0.56669  ,  0.34927  ,  0.34816  , -0.014743 ,
        0.97688  ,  0.17702  ,  0.16185  ,  0.044074 , -0.68819  ,
        0.18073  ,  0.26355  ,  0.36275  , -0.73523  ,  0.39962  ,
        0.0037411, -0.15352  ,  0.10079  , -0.23187  , -0.7068   ,
        0.32768  , -0.012518 ,  0.038887 ,  0.67385  , -1.1839   ,
        0.91321  , -0.0060804,  0.026679 ,  0.42256  , -0.10934  ,
       -0.25663  , -0.22761  ,  0.34171  , -0.47256  , -0.075018 ,
       -0.55013  ,  0.5073   ,  0.096439 , -0.14561  ,  0.21227  ,
       -0.82953  ,  0.33062  ,  0.064787 ,  0.106    , -0.25982  ,
        0.24861  ,  0.2334   ,  0.45757  , -0.38603  , -0.19482  ,
       -0.83137  , -0.097219 , -0.23189  ,  0.21918  , -0.6416