# ScaNN 示例：在 GloVe 数据集上进行近似最近邻搜索

## 安装并导入依赖
> 环境：Python 3.8
> 
> 安装：`pip install scann h5py requests`

In [None]:
import io
import os
import tempfile
import time

import h5py
import numpy as np
import requests
import scann

## 下载示例数据

In [None]:
# Download the GloVe-100 (angular) benchmark file and open it with h5py.
# The file is opened from an in-memory buffer rather than from a temporary
# directory: h5py accepts Python file-like objects, so the handle stays valid
# for as long as `glove_h5py` is used by the cells below, and nothing on disk
# has to be deleted while it is still open (which fails on Windows).
GLOVE_URL = "http://ann-benchmarks.com/glove-100-angular.hdf5"

response = requests.get(GLOVE_URL, timeout=60)
# Fail loudly on HTTP errors instead of handing an error page to h5py.
response.raise_for_status()

glove_h5py = h5py.File(io.BytesIO(response.content), "r")
# Show the available datasets; the cells below use 'train', 'test' and
# 'neighbors'. (A bare expression nested in a block is not echoed by Jupyter.)
print(list(glove_h5py.keys()))

In [None]:
# Vectors to be indexed ('train') and held-out query vectors ('test').
dataset, queries = glove_h5py['train'], glove_h5py['test']
for split in (dataset, queries):
    print(split.shape)

## 创建ScaNN索引

In [None]:
# Scale every training vector to unit length so that dot-product search
# matches the angular metric of this benchmark.
# Materialize the h5py dataset once; passing it directly to both
# np.linalg.norm and the division would read it from the HDF5 file twice.
train_vectors = np.asarray(dataset)
norms = np.linalg.norm(train_vectors, axis=1, keepdims=True)
# Keep all-zero rows at zero instead of turning them into NaN (0 / 0).
normalized_dataset = train_vectors / np.where(norms == 0, 1, norms)

# Tree partitioning + asymmetric hashing (AH) scoring + exact re-ranking:
#   builder(..., 10, "dot_product"): return 10 neighbors ranked by dot product
#   tree(...): 2000 leaves, probe 100 per query, train on 250000 sampled points
#   score_ah(2, ...): AH with 2 dimensions per block, anisotropic quantization
#   reorder(100): re-score the top 100 AH candidates exactly
# For a TensorFlow-compatible searcher, build the same chain starting from
# scann.scann_ops.builder(...) instead of scann.scann_ops_pybind.builder(...).
searcher = (
    scann.scann_ops_pybind.builder(normalized_dataset, 10, "dot_product")
    .tree(num_leaves=2000, num_leaves_to_search=100, training_sample_size=250000)
    .score_ah(2, anisotropic_quantization_threshold=0.2)
    .reorder(100)
    .build()
)

In [None]:
def compute_recall(neighbors, true_neighbors):
    """Return the fraction of ground-truth neighbors that were retrieved.

    Args:
        neighbors: 2-D array-like of retrieved neighbor ids, one row per query.
        true_neighbors: 2-D array-like of ground-truth neighbor ids, row-aligned
            with ``neighbors``. Plain nested lists are accepted as well as
            NumPy arrays.

    Returns:
        Sum over rows of the number of distinct ids shared by the two rows,
        divided by the total number of ground-truth entries.

    Raises:
        ZeroDivisionError: if ``true_neighbors`` is empty.
    """
    # asarray lets callers pass lists; `.size` would fail on a plain list.
    true_neighbors = np.asarray(true_neighbors)
    hits = sum(
        np.intersect1d(gt_row, row).size
        for gt_row, row in zip(true_neighbors, neighbors)
    )
    return hits / true_neighbors.size

## 使用 ScaNN 搜索并权衡召回率与速度

In [None]:
# Baseline: probe the top 100 of the 2000 leaves, then exactly re-score the
# top 100 asymmetric-hashing candidates to obtain the final neighbors.
# perf_counter is monotonic and high-resolution, unlike time.time().
start = time.perf_counter()
neighbors, distances = searcher.search_batched(queries)
end = time.perf_counter()

# Compare against the first 10 ground-truth neighbors of every query.
print("Recall:", compute_recall(neighbors, glove_h5py['neighbors'][:, :10]))
print("Time:", end - start)

In [None]:
# Probing more leaves per query (150 instead of 100) is expected to raise
# recall at the cost of speed.
start = time.perf_counter()
neighbors, distances = searcher.search_batched(queries, leaves_to_search=150)
end = time.perf_counter()

print("Recall:", compute_recall(neighbors, glove_h5py['neighbors'][:, :10]))
print("Time:", end - start)

In [None]:
# Re-scoring more top AH candidates exactly (250 instead of 100) has a
# similar recall-vs-speed effect.
start = time.perf_counter()
neighbors, distances = searcher.search_batched(
    queries, leaves_to_search=150, pre_reorder_num_neighbors=250
)
end = time.perf_counter()

print("Recall:", compute_recall(neighbors, glove_h5py['neighbors'][:, :10]))
print("Time:", end - start)

In [None]:
# The number of neighbors returned can be chosen per call: first the default
# fixed when the searcher was built, then 20 neighbors per query.
for call_kwargs in ({}, {"final_num_neighbors": 20}):
    neighbors, distances = searcher.search_batched(queries, **call_kwargs)
    print(neighbors.shape, distances.shape)

In [None]:
# Single-query search: the 5 nearest neighbors of the first query.
# perf_counter gives a monotonic, high-resolution clock suitable for
# millisecond-scale latency measurements.
start = time.perf_counter()
neighbors, distances = searcher.search(queries[0], final_num_neighbors=5)
end = time.perf_counter()

print(neighbors)
print(distances)
print("Latency (ms):", 1000*(end - start))