In [1]:
from collections import OrderedDict, defaultdict
from typing import Callable, Tuple, Dict, List

import numpy as np
from tqdm.auto import tqdm

In [68]:
# pointA: (1, D); documents: (N, D)
def distance(pointA: np.ndarray, documents: np.ndarray) -> np.ndarray:
    return np.linalg.norm(pointA - documents, axis=1).reshape(-1, 1)

In [69]:
pointA = np.random.rand(1, 3)
documents = np.random.rand(10, 3)

In [70]:
dist = distance(pointA, documents)

In [71]:
dist.shape

(10, 1)

In [6]:
dist

array([[0.92746416],
       [0.81935806],
       [0.74958029],
       [0.68186471],
       [1.1651381 ],
       [0.73913866],
       [1.10971363],
       [0.83560655],
       [1.08783558],
       [0.77737634]])

In [7]:
np.argsort(dist, axis=0)

array([[3],
       [5],
       [2],
       [9],
       [1],
       [7],
       [0],
       [8],
       [6],
       [4]], dtype=int64)

In [8]:
dist[np.argsort(dist, axis=0)]

array([[[0.68186471]],

       [[0.73913866]],

       [[0.74958029]],

       [[0.77737634]],

       [[0.81935806]],

       [[0.83560655]],

       [[0.92746416]],

       [[1.08783558]],

       [[1.10971363]],

       [[1.1651381 ]]])

In [9]:
np.argsort(dist, axis=0)[-5:]

array([[7],
       [0],
       [8],
       [6],
       [4]], dtype=int64)

In [10]:
a = np.argsort(dist, axis=0).reshape(-1,)[:5]

In [11]:
np.random.choice(a, size=2, replace=False)

array([3, 2], dtype=int64)

In [12]:
{0: 1}

{0: 1}

In [13]:
a = [0, 1]
a.extend([2])
a

[0, 1, 2]

In [14]:
print([0, 1].extend([2]))

None


In [180]:
def create_sw_graph(
        data: np.ndarray,
        num_candidates_for_choice_long: int = 10,
        num_edges_long: int = 5,
        num_candidates_for_choice_short: int = 10,
        num_edges_short: int = 5,
        use_sampling: bool = False,
        sampling_share: float = 0.05,
        dist_f: Callable = distance
        ) -> Dict[int, List[int]]:
    """
    creates small world graph with closest and furthest points
    :param data:
    :param num_candidates_for_choice_long: num furthest points from cur point
    :param num_edges_long: num sample from previous candidates
    :param num_candidates_for_choice_short: num closest points to cur point
    :param num_edges_short:
    :param use_sampling: if we have too much data we can use sampling
    :param sampling_share: proportion of data to sample
    :param dist_f: function for distance (euclidean)
    :return: sw graph
    """
    # my code here
    graph = {}
    num_points = data.shape[0]
    for i, point in enumerate(data):
        if not use_sampling:
            all_dists = dist_f(data[i, :], data)
            argsorted = np.argsort(all_dists.reshape(1, -1))[0][1:]
        else:
            sample_size = int(num_points * sampling_share)
            choiced = np.random.choice(
                list(range(num_points)), size=sample_size, replace=False)
            part_dists = dist_f(data[i, :], data[choiced, :])
            argsorted = choiced[np.argsort(part_dists.reshape(1, -1))[0][1:]]
        
        candidates_for_point = []
        further_points = argsorted[-num_candidates_for_choice_long:]
        further_points = np.random.choice(further_points, size=num_edges_long,
                                          replace=False)
        candidates_for_point.extend(list(further_points))

        closer_points = argsorted[:num_candidates_for_choice_short]
        closer_points = np.random.choice(closer_points, size=num_edges_short,
                                         replace=False)
        candidates_for_point.extend(list(closer_points))
        graph[i] = candidates_for_point
    return graph

In [181]:
documents = np.random.rand(10000, 5)

In [182]:
documents[0].reshape(1, -1).shape

(1, 5)

In [183]:
graph = create_sw_graph(documents)

In [184]:
graph[0]

[9904, 8448, 9091, 5774, 5938, 3370, 6304, 6286, 367, 4553]

In [186]:
# check graph. First five docs are most distant ones
# last 5 are most closest ones. Everything is ok
np.linalg.norm(documents[0] - documents[9904]), np.linalg.norm(documents[0] - documents[4553]),

(1.441386070731166, 0.1778228617595581)

In [228]:
def calc_dist_and_upd(all_visited_points: dict,
                      query_point: np.ndarray,
                      all_documents: np.ndarray,
                      point_idx: int,
                      dist_f: Callable
                      ) -> Tuple[float, bool]:
    """
    calculate distance between query point and specific point from all docs
    :param all_visited_points: if spec point in visited we will not calculate
    :param query_point:
    :param all_documents:
    :param point_idx:
    :param dist_f:
    :return: distance and is_visited flag
    """
    if point_idx in all_visited_points:
        return all_visited_points[point_idx], True
    cur_dist = \
        dist_f(query_point, all_documents[point_idx, :].reshape(1, -1))[0][0]
    all_visited_points[point_idx] = cur_dist
    return cur_dist, False


def nsw(query_point: np.ndarray,
        all_documents: np.ndarray,
        graph_edges: Dict,
        search_k: int = 10, num_start_points: int = 5,
        dist_f: Callable = distance) -> np.ndarray:
    """
    do navigable search with help of small world graph
    :param query_point: input for which we want to find closest documents
    :param all_documents: all available data
    :param graph_edges: small world graph
    :param search_k: num of output documents
    :param num_start_points: 5 init point to start a search through graph
    :param dist_f: eucledian distance
    :return: approximate top k closest docs
    """
    all_visited_points = {}
    num_started_points = 0
    while (num_started_points < num_start_points) or \
            (len(all_visited_points) < search_k):
        cur_point_idx = np.random.randint(0, all_documents.shape[0] - 1)
        cur_dist, is_visited = calc_dist_and_upd(
            all_visited_points, query_point, all_documents,
            cur_point_idx, dist_f)
        if is_visited:
            continue

        while True:
            min_dist = cur_dist
            choiced_cand = cur_point_idx

            cands_idxs = graph_edges[cur_point_idx]
            visited_before_cands = {cur_point_idx}
            for cand_idx in cands_idxs:
                tmp_d, is_visited = calc_dist_and_upd(
                    all_visited_points, query_point, all_documents,
                    cand_idx, dist_f)
                if tmp_d < min_dist:
                    min_dist = tmp_d
                    choiced_cand = cand_idx
                if is_visited:
                    visited_before_cands.add(cand_idx)

            if choiced_cand in visited_before_cands:
                break
            cur_dist = min_dist
            cur_point_idx = choiced_cand

        num_started_points += 1

    best_idxs = np.argsort(list(all_visited_points.values()))[:search_k]
    final_idx = np.array(list(all_visited_points.keys()))[best_idxs]
    return final_idx

In [229]:
query = np.array([1., 1., 1., 1., 1.])

In [230]:
documents = np.random.rand(10000, 5)

In [231]:
candidates = nsw(query, documents, graph)

In [232]:
documents[56].shape

(5,)

In [233]:
candidates

array([2755, 8470, 3277, 3983,  363, 3112, 1509,  147, 2493, 5103],
      dtype=int64)

In [250]:
np.linalg.norm(query - documents[2755]), np.linalg.norm(query - documents[5103])

(0.5181845590310753, 1.0078735656227376)

In [235]:
documents[2755]

array([0.98787889, 0.7856561 , 0.77278463, 0.65625646, 0.77056905])

In [236]:
all_dists = distance(query, documents)
argsorted = np.argsort(all_dists.reshape(1, -1))[0][1:]

In [237]:
argsorted[:20]

array([7818, 1262, 9029, 2149, 7929, 4738, 8359, 1001, 9460, 2115, 7413,
       1205, 9907, 4175, 3855,  173, 1810, 8767, 8897, 4555], dtype=int64)

In [253]:
# as you can see 5 start points is too small to find most closest docs
np.linalg.norm(query - documents[7818])

0.24574162882912523

In [241]:
documents[9503]

array([0.43784773, 0.05931726, 0.00223698, 0.51774528, 0.05959232])

In [246]:
candidates = nsw(query, documents, graph, num_start_points=30)

In [247]:
candidates

array([3153, 9761, 8146, 4460, 9591, 2012, 2487, 3651, 6840, 7579],
      dtype=int64)

In [252]:
# it's better now
np.linalg.norm(query - documents[3153]), np.linalg.norm(query - documents[7579])

(0.41613454891452917, 0.6182968005997975)