In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from collections import defaultdict
from range_index import create_range_index
from utils import parse_ann_benchmarks_hdf5
import numpy as np
import time
from tqdm import tqdm


# Dataset: an ann-benchmarks HDF5 file plus a matching array of per-point filter values.
# Paths and the results file are all derived from `dataset_name` so they cannot drift apart.
dataset_dir = "/data/scratch/jae/ann_benchmarks_datasets"
dataset_name = "sift-128-euclidean"
filter_path = f"{dataset_dir}/{dataset_name}_filters.npy"
data_path = f"{dataset_dir}/{dataset_name}.hdf5"

index = create_range_index(data_path, filter_path)

# NOTE(review): element [1] of the parsed HDF5 tuple is assumed to be the query
# vectors — confirm against parse_ann_benchmarks_hdf5.
queries = parse_ann_benchmarks_hdf5(data_path)[1]

# For angular datasets, unit-normalize each query row (presumably so the index's
# distance matches cosine similarity — confirm against create_range_index).
if "angular" in data_path:
    queries = queries / np.linalg.norm(queries, axis=-1)[:, np.newaxis]


# TODO: Should also vary index quality
# TODO: Should also vary cutoff level
# TODO: Try on larger datasets


top_k = 10
# Results file is named after the dataset so runs on different datasets never share a file.
output_file = f"{dataset_name}_experiment.txt"

# Append mode preserves earlier runs. Only write the CSV header when the file is
# empty (tell() is at end-of-file in append mode), so re-running this cell does
# not insert duplicate header rows into the middle of the results.
with open(output_file, "a") as f:
    if f.tell() == 0:
        f.write("filter_width,method,recall,average_time\n")


In [38]:
# Benchmark: for each filter width, draw a random filter window per query and compare
# our range-filtered `index.query` and postfiltering against prefiltering, whose
# results are used as ground truth (`gt`). Per-method mean recall and mean latency
# are appended to `output_file` as CSV rows.

filter_widths = [0.1]
num_queries = 100
query_complexities = [10, 20, 40, 80, 160, 320, 640]
num_doubles = 6  # postfilter_doubles / extra_doubles are each swept over range(num_doubles)
rng = np.random.default_rng(0)  # fixed seed: filter windows are reproducible across re-runs


def recall_at_k(true_ids, found_ids):
    """Return the fraction of ``true_ids`` that also appear in ``found_ids``.

    Returns 1.0 when ``true_ids`` is empty (no points inside the filter window, so
    nothing could be missed) instead of raising ZeroDivisionError.
    """
    true_list = list(true_ids)
    if not true_list:
        return 1.0
    found = set(found_ids)  # O(1) membership instead of scanning the result per id
    return sum(1 for x in true_list if x in found) / len(true_list)


for filter_width in filter_widths:
    run_results = defaultdict(list)
    for q in tqdm(queries[:num_queries]):
        random_filter_start = rng.uniform(0, 1 - filter_width)
        filter_range = (random_filter_start, random_filter_start + filter_width)

        start = time.perf_counter()
        gt = index.prefilter_query(q, top_k=top_k, filter_range=filter_range)
        # Prefiltering output is the ground truth here, so its recall is recorded as 1.
        run_results["prefiltering"].append((1, time.perf_counter() - start))

        for postfilter_doubles in range(num_doubles):
            for query_complexity in query_complexities:
                start = time.perf_counter()
                our_result = index.query(
                    q,
                    top_k=top_k,
                    query_complexity=query_complexity,
                    filter_range=filter_range,
                    postfilter_doubles=postfilter_doubles
                )
                # Read the clock before computing recall so scoring isn't billed as query time.
                elapsed = time.perf_counter() - start
                run_results[f"ours_{query_complexity}_{postfilter_doubles}"].append(
                    (recall_at_k(gt[0], our_result[0]), elapsed)
                )

        for extra_doubles in range(num_doubles):
            start = time.perf_counter()
            postfilter_result = index.postfilter_query(
                q, top_k=top_k, filter_range=filter_range, extra_doubles=extra_doubles
            )
            elapsed = time.perf_counter() - start
            run_results[f"postfiltering_{extra_doubles}"].append(
                (recall_at_k(gt[0], postfilter_result[0]), elapsed)
            )

    # One CSV row per method: filter width, method name, mean recall, mean latency (s).
    with open(output_file, "a") as f:
        for name, zipped_recalls_times in run_results.items():
            recalls = [r for r, _ in zipped_recalls_times]
            times = [t for _, t in zipped_recalls_times]
            f.write(f"{filter_width},{name},{np.mean(recalls)},{np.mean(times)}\n")

100%|██████████| 1000/1000 [01:32<00:00, 10.82it/s]
100%|██████████| 1000/1000 [02:42<00:00,  6.17it/s]
100%|██████████| 1000/1000 [05:56<00:00,  2.80it/s]
100%|██████████| 1000/1000 [06:58<00:00,  2.39it/s]
100%|██████████| 1000/1000 [07:33<00:00,  2.20it/s]
100%|██████████| 1000/1000 [08:43<00:00,  1.91it/s]
 25%|██▌       | 252/1000 [02:25<07:13,  1.73it/s]


KeyboardInterrupt: 