In [2]:
import sys
import os.path

sys.path.append('../modules')
from nsw import Node, NSWGraph
from nsw_visualization import show_state
import data_gen as dg

In [3]:
def callback_searcher(leader, cloud):
    subset = data[np.array(cloud)[:,1].astype(int),:]
    clear_output(wait=True)
    show_state(data, subset, leader=leader.reshape(1, -1), target=target)

In [7]:
import time
import math
import numpy as np

grid = {
    "size": [2000, 10000, 100000],
    "dim": [2, 5, 10, 50, 100],
    "regularity_ratio": [1, 2, 5, 10],
    "multisearch": [5, 10, 20],
    "top": [10, 100, 1000]
}

test_size = 100
approx_constant = 10

times_build = {}
times_search = {}
accuracy = {}

In [None]:
for dim in grid["dim"]:
    for size in grid["size"]:
        data = np.random.rand(size, dim)
        test = np.random.rand(test_size, dim)
        print(f"Dataset for {size} samples with {dim} dimensions generated")
        data_with_classes = list((row, 0) for row in data)
        print("Data with labels is created")
        
        dist = data @ test.T
        
        # for each test sample
        closest = []
        for i in range(test.shape[0]):
            cl = np.argpartition(dist[:, i], max(grid["top"])).tolist()
            closest.append(cl)
        print("Ground truth is generated")
        
        for multisearch in grid["multisearch"]:
            for regularity_ratio in grid["regularity_ratio"]:
                regularity = math.ceil(math.log(size) * regularity_ratio)

                tpl = (dim, size, regularity, multisearch) 
                
                filename = f"../dumps/{dim}D_{size}items_{regularity}regular_{multisearch}repeat.graph"
                if os.path.exists(filename):
                    G = NSWGraph.load(filename)
                    print(f"Graph [K={regularity}, Repeat={multisearch}] is loaded from file.")
                else:
                    start = time.time()
                    G = NSWGraph()
                    G.build_navigable_graph(data_with_classes,  K=regularity, attempts=multisearch)
                    fin = time.time()
                    t = fin - start
                    G.save(filename)
                    print(f"Graph [K={regularity}, Repeat={multisearch}] is generated in {t:.2f} sec.")
                    times_build[tpl] = t
                
                for top in grid["top"]:
                    tpl = (dim, size, regularity, multisearch, top)                 
                    match, match_scaled, total, total_scaled = 0, 0, 0, 0
                    for i, row in enumerate(test):
                        start += time.time()
                        result = G.multi_search(row, attempts=multisearch, top=top)
                        fin += time.time()
                        
                        result = set(result)
                        intersect = len(result.intersection(closest[i][:top]))
                        intersect_scaled = len(result.intersection(closest[i][:top // approx_constant]))
                        match += intersect
                        match_scaled += intersect_scaled
                        total += top
                        total_scaled += top // approx_constant
                    accuracy[tpl] = (match, total)
                    print(f'top {top} ~ {100 * match / total:.2f}% ; scaled[{approx_constant}] ~ {100 * match_scaled / total_scaled:.2f}%')
                    times_search[tpl] = (fin - start) / test.shape[0] 

Dataset for 2000 samples with 2 dimensions generated
Data with labels is created
Ground truth is generated
Graph [K=8, Repeat=5] is loaded from file.
top 10 ~ 0.30% ; scaled[10] ~ 0.00%
top 100 ~ 5.23% ; scaled[10] ~ 5.30%
top 1000 ~ 21.82% ; scaled[10] ~ 26.68%
Graph [K=16, Repeat=5] is loaded from file.
top 10 ~ 0.30% ; scaled[10] ~ 0.00%
top 100 ~ 4.75% ; scaled[10] ~ 4.80%
top 1000 ~ 28.33% ; scaled[10] ~ 32.99%
Graph [K=39, Repeat=5] is loaded from file.
top 10 ~ 0.30% ; scaled[10] ~ 0.00%
top 100 ~ 4.61% ; scaled[10] ~ 4.70%
top 1000 ~ 41.88% ; scaled[10] ~ 44.98%
Graph [K=77, Repeat=5] is loaded from file.
top 10 ~ 0.30% ; scaled[10] ~ 0.00%
top 100 ~ 5.00% ; scaled[10] ~ 5.20%
top 1000 ~ 48.22% ; scaled[10] ~ 47.39%
Graph [K=8, Repeat=10] is loaded from file.
top 10 ~ 0.30% ; scaled[10] ~ 0.00%
top 100 ~ 4.81% ; scaled[10] ~ 4.60%
top 1000 ~ 32.82% ; scaled[10] ~ 36.25%
Graph [K=16, Repeat=10] is loaded from file.
top 10 ~ 0.30% ; scaled[10] ~ 0.00%
top 100 ~ 4.73% ; scaled[10]