In [1]:
!pip install torch
!pip install tqdm



In [73]:
from collections import OrderedDict, defaultdict
from typing import Callable, Tuple, Dict, List
from queue import PriorityQueue
import numpy as np
from tqdm.auto import tqdm


def distance(pointA: np.ndarray, documents: np.ndarray) -> np.ndarray:
    matrix = np.repeat(pointA, len(documents), axis=0)
    res = (documents - matrix)**2
    res = np.sqrt(np.sum(res, axis=1))
    return res.reshape(res.shape[0], 1)


def create_sw_graph(
        data: np.ndarray,
        num_candidates_for_choice_long: int = 10,
        num_edges_long: int = 5,
        num_candidates_for_choice_short: int = 10,
        num_edges_short: int = 5,
        use_sampling: bool = False,
        sampling_share: float = 0.05,
        dist_f: Callable = distance
    ) -> Dict[int, List[int]]:
    res = {}
    for i, point in enumerate(data):
        links = []
        if use_sampling:
            distancetoAll =  dist_f(point, data[np.random.randint(0, len(data) - 1, int(len(data) * sampling_share))])
        else:
            distancetoAll = dist_f(point.reshape((1, len(point))), data)
        sortedDistance = np.argsort(distancetoAll.reshape(len(distancetoAll)))
        sortedDistance = sortedDistance[sortedDistance != i]
        links = [*sortedDistance[-num_edges_long:], *sortedDistance[:num_edges_short]]
        res.update({i: links})
    return res
        

def control_queue(queue: list, visited_vertex: dict, search_k: int, all_vertex: list):
    if queue and len(visited_vertex) < search_k:
        remainder = list(set(all_vertex).difference(set(visited_vertex.keys())))
        queue.append(np.random.choice(remainder, 1)[0])

def nsw(query_point: np.ndarray, all_documents: np.ndarray, 
        graph_edges: Dict[int, List[int]],
        search_k: int = 10, num_start_points: int = 5,
        dist_f: Callable = distance) -> np.ndarray:
    
    all_vertex = list(range(all_documents.shape[0]))
    
    queue = list(np.random.choice(all_vertex, num_start_points, replace=False))
    visited_vertex = dict()
    
    while queue:
        point = queue.pop()
        if point in visited_vertex:
            control_queue(queue, visited_vertex, search_k, all_vertex)
            continue
        else:
            neighbours = []
            for neighbour in graph_edges[point]:
                if neighbour in visited_vertex:
                    continue
                neighbours.append(neighbour)
            distances = dist_f(query_point, all_documents[neighbours]).squeeze()
            if len(neighbours) == 1:
                distances = [distances]
            visited_vertex.update(list(zip(neighbours, distances)))
            queue.extend(neighbours)
        control_queue(queue, visited_vertex, search_k, all_vertex)
        
    nearest = list(zip(*sorted(visited_vertex.items(), key=lambda x: x[1])))[0][:search_k]
    return nearest


In [78]:
def trueNearest(
    query_point: np.ndarray, 
    all_documents: np.ndarray
) -> np.ndarray:
    dist = distance(query, all_documents)
    indice = np.argsort(dist.reshape(dist.shape[0]))[:20]
    return indice

In [79]:
data = np.random.randint(0, 400, (10000, 4))
query = np.random.randint(0, 400, (1, 4))
edges = create_sw_graph(data)

aprox = nsw(query, data, edges)
true = trueNearest(query, data)

print(data[np.sort(true)])
print(data[np.sort(aprox)])
print(f'{len(np.intersect1d(true, aprox)) / 10 * 100}%')

[[113 259 121   2]
 [ 85 268 111   2]
 [ 61 309 120  17]
 [116 257  78  42]
 [ 94 243  82  27]
 [135 259 117  25]
 [100 254 122  24]
 [ 76 237 139  65]
 [ 89 310 127  75]
 [118 298 117  38]
 [ 74 285 105  99]
 [ 65 306  74  46]
 [ 48 288 136  73]
 [ 56 260 101  57]
 [120 323 117  39]
 [ 81 276  74  10]
 [ 85 243 116  73]
 [109 251 117  89]
 [119 270 101  69]
 [ 85 241 127  50]]
[[116 257  78  42]
 [116 257  78  42]
 [ 67 311  55 294]
 [ 76 191 160  22]
 [ 54 242 172 125]
 [189 284 179  54]
 [ 41 313  77 301]
 [ 41 313  77 301]
 [205 121  94 120]
 [ 85 241 127  50]]
20.0%
