In [8]:
import sys
import os
from pathlib import Path
import numpy as np
import time

# Put the ParlayANN Python directory on the import path. The relative path is
# resolved against the current working directory, so the kernel has to be
# started from the directory that contains ./ParlayANN.
# NOTE(review): presumably this is what makes `import wrapper` below resolve —
# confirm, and consider installing the bindings as a package instead.
sys.path.append(str(Path('./ParlayANN/python').resolve()))

import wrapper as wp

from utils import *

In [9]:
# Thread count exported through the PARLAY_NUM_THREADS environment variable.
# NOTE(review): presumably read by the ParlayANN scheduler when it first starts;
# confirm that setting it after `import wrapper` is still early enough.
THREADS = 96
os.environ.update({"PARLAY_NUM_THREADS": f"{THREADS}"})

In [10]:
# Load the base vectors, the query vectors and the per-point filter values.
data_dir = "/ssd1/anndata/ann-benchmarks/"
dataset_name = "sift-128-euclidean"

data_path = os.path.join(data_dir, f"{dataset_name}.hdf5")
filter_path = os.path.join(data_dir, f"{dataset_name}_filters.npy")

# Parse the HDF5 file once and reuse the result; element 0 is used as the
# base data and element 1 as the queries.
parsed_dataset = parse_ann_benchmarks_hdf5(data_path)
data = parsed_dataset[0]
queries = parsed_dataset[1]

filter_values = np.load(filter_path)

if 'angular' in dataset_name:
    metric = "mips"
    # Scale every base vector to unit L2 norm; keepdims keeps the norms
    # broadcastable against `data` without re-adding the axis by hand.
    # Only `data` is normalised here, the queries are left as loaded.
    data = data / np.linalg.norm(data, axis=-1, keepdims=True)
else:
    # NOTE(review): spelling kept as-is because this string is passed verbatim
    # to the wrapper constructors — confirm it is the name they expect.
    metric = "Euclidian"

In [11]:
# Build the prefiltering index and report how long construction took.
print("prefiltering index build")
prefilter_constructor = wp.prefilter_index_constructor(metric, 'float')
# perf_counter() is a monotonic, high-resolution clock meant for measuring
# intervals; time.time() can jump if the system clock is adjusted mid-run.
prefilter_build_start = time.perf_counter()
prefilter_index = prefilter_constructor(data, filter_values)
prefilter_build_end = time.perf_counter()
prefilter_build_time = prefilter_build_end - prefilter_build_start
print(f"prefiltering index build time: {prefilter_build_time:.3f}s")

prefiltering index build
prefiltering index build time: 0.433s


In [12]:
# Build the range-filter tree index and report how long construction took.
constructor = wp.range_filter_tree_index_constructor(metric, 'float')
print("building index")
# perf_counter() is a monotonic, high-resolution clock meant for measuring
# intervals; time.time() can jump if the system clock is adjusted mid-run.
index_build_start = time.perf_counter()
# NOTE(review): the meaning of the third argument (1_000) is not visible in
# this notebook — presumably a size cutoff for the tree; confirm and name it.
index = constructor(data, filter_values, 1_000)
index_build_end = time.perf_counter()
index_build_time = index_build_end - index_build_start
print(f"index build time: {index_build_time:.3f}s")

building index
index build time: 1.634s


In [13]:
# Query parameters: neighbours requested per query and the width of each
# query's filter range.
top_k = 10
filter_width = 0.25

# Fixed seed so the random filter ranges, and every recall / QPS figure
# derived from them, are reproducible after a kernel restart.
FILTER_SEED = 0
filter_rng = np.random.default_rng(FILTER_SEED)

# One range centre per query, drawn so that the whole window
# [centre - width/2, centre + width/2] stays inside [0, 1].
# NOTE(review): assumes filter_values are spread over [0, 1] — confirm.
raw_filters = filter_rng.uniform(filter_width / 2, 1 - filter_width / 2, size=len(queries))

# (n_queries, 2) array of (low, high) bounds, built without a Python-level loop.
half_width = filter_width / 2
filters = np.column_stack((raw_filters - half_width, raw_filters + half_width))

In [14]:
# Run all queries through the prefiltering index and time the whole batch.
print("prefilter querying")
# perf_counter() is a monotonic, high-resolution clock meant for measuring
# intervals; time.time() can jump if the system clock is adjusted mid-run.
prefiltering_start = time.perf_counter()
prefilter_results = prefilter_index.batch_query(queries, filters, queries.shape[0], top_k)
prefiltering_end = time.perf_counter()
prefiltering_time = prefiltering_end - prefiltering_start
print(f"prefiltering time: {prefiltering_time:.3f}s")

prefilter querying
prefiltering time: 20.246s


In [15]:
# Run all queries through the range-filter tree index and time the whole batch.
print("index querying")
# perf_counter() is a monotonic, high-resolution clock meant for measuring
# intervals; time.time() can jump if the system clock is adjusted mid-run.
start = time.perf_counter()
index_results = index.batch_filter_search(queries, filters, queries.shape[0], top_k)
end = time.perf_counter()
index_time = end - start
print(f"index time: {index_time:.3f}s")

index querying
index time: 20.869s


In [16]:
def compute_recall(gt_neighbors, results, top_k):
    """Return the mean per-query recall of `results` against `gt_neighbors`.

    For query i the recall is the fraction of the distinct ids in
    `gt_neighbors[i]` that also appear among the first `top_k` entries of
    `results[i]`. Only `results` is truncated to `top_k`; the ground-truth
    row is used in full.

    Queries whose ground-truth row is empty cannot be scored and are left out
    of the average instead of triggering a division by zero.

    Args:
        gt_neighbors: sequence of per-query iterables of ground-truth ids.
        results: indexable sequence of per-query sliceable id sequences, with
            at least as many rows as `gt_neighbors`.
        top_k: number of leading entries of each result row to consider.

    Returns:
        The average recall as a float in [0, 1], or 0.0 when no query has any
        ground-truth neighbours (including when there are no queries at all).

    Raises:
        IndexError: if `results` has fewer rows than `gt_neighbors`.
    """
    total_recall = 0.0
    scored_queries = 0
    for i, gt_row in enumerate(gt_neighbors):  # for each query
        gt = set(gt_row)
        if not gt:
            # No ground truth for this query: nothing to recall, skip it.
            continue
        res = set(results[i][:top_k])
        total_recall += len(gt & res) / len(gt)
        scored_queries += 1

    if scored_queries == 0:
        return 0.0
    return total_recall / scored_queries  # average recall per scored query

In [18]:
# Recall of the tree index, scored against the prefilter index's results,
# which serve as the ground truth here.
# NOTE(review): element [0] of each result is presumably the neighbour ids — confirm.
index_recall = compute_recall(prefilter_results[0], index_results[0], top_k)
print(f"index recall: {index_recall*100:.2f}%")

num_queries = queries.shape[0]

# Mean latency per query, in seconds (printed in milliseconds).
index_average_time, prefilter_average_time = (
    index_time / num_queries,
    prefiltering_time / num_queries,
)
print(f"index average time: {index_average_time*1000:.2f}ms")
print(f"prefilter average time: {prefilter_average_time*1000:.2f}ms")

# Throughput in queries per second.
index_qps, prefilter_qps = (
    num_queries / index_time,
    num_queries / prefiltering_time,
)
print(f"index qps: {index_qps:.2f}")
print(f"prefilter qps: {prefilter_qps:.2f}")

index recall: 99.98%
index average time: 2.09ms
prefilter average time: 2.02ms
index qps: 479.19
prefilter qps: 493.94
