In [286]:
from collections import OrderedDict, defaultdict
from typing import Callable, Tuple, Dict, List

import numpy as np
from tqdm.auto import tqdm


def distance(pointA: np.ndarray, documents: np.ndarray) -> np.ndarray:
    return np.linalg.norm(pointA - documents, axis=1).reshape(-1,1)
#     return np.linalg.norm(pointA - documents, axis=1, keepdims=True) # или так

    # where the system will automatically compute the correct shape instead "-1"
#     происходит автоматический broadcast (из матрицы 1*D вычитается матрица N*D)

def create_sw_graph(
        data: np.ndarray,
        num_candidates_for_choice_long: int = 10,
        num_edges_long: int = 5,
        num_candidates_for_choice_short: int = 10,
        num_edges_short: int = 5,
        use_sampling: bool = False,
        sampling_share: float = 0.05,
        dist_f: Callable = distance
    ) -> Dict[int, List[int]]:
    
    ind_dict = defaultdict(list)
#     ind_dict[12].append(...) # можно обращаться к несуществующему ключу

#     np.random.seed(1)
    
    for i, r in enumerate(data):
        
        if use_sampling:
            choose_n = int(sampling_share * len(data))
            random_ind = np.random.choice(np.arange(len(data)), choose_n, replace=False)
            dist = dist_f(r, data[random_ind])
        else:
            dist = dist_f(r, data)
            
        cand_ind_sort_long = np.argsort(dist, axis=0)[::-1][:-1][:num_candidates_for_choice_long]
        cand_ind_sort_long = cand_ind_sort_long.reshape(-1,)
        top_ind_sort_long = np.random.choice(cand_ind_sort_long, num_edges_long, replace=False)
        
        cand_ind_sort_short = np.argsort(dist, axis=0)[1:][:num_candidates_for_choice_short]
        cand_ind_sort_short = cand_ind_sort_short.reshape(-1,)
        top_ind_sort_short = np.random.choice(cand_ind_sort_short, num_edges_short, replace=False)
        
        short_and_long_ind = np.concatenate((top_ind_sort_long, top_ind_sort_short))
        
        if use_sampling:
            short_and_long_ind = data[random_ind][short_and_long_ind]
            
        ind_dict[i] = list(short_and_long_ind)
        
    return ind_dict

# ideal
def create_sw_graph(
        data: np.ndarray,
        num_candidates_for_choice_long: int = 10,
        num_edges_long: int = 5,
        num_candidates_for_choice_short: int = 10,
        num_edges_short: int = 5,
        use_sampling: bool = False,
        sampling_share: float = 0.05,
        dist_f: Callable = distance
    ) -> Dict:
    edges = defaultdict(list)
    num_points = data.shape[0]

    for cur_point_idx in tqdm(range(num_points)):
        if not use_sampling:
            all_dists = dist_f(data[cur_point_idx, :], data)
            argsorted = np.argsort(all_dists.reshape(1, -1))[0][1:]
        else:
            sample_size = int(num_points * sampling_share)
            choiced = np.random.choice(
                list(range(num_points)), size=sample_size, replace=False)
            part_dists = dist_f(data[cur_point_idx, :], data[choiced, :])
            argsorted = choiced[np.argsort(part_dists.reshape(1, -1))[0][1:]]

        short_cands = argsorted[:num_candidates_for_choice_short]
        short_choice = np.random.choice(
            short_cands, size=num_edges_short, replace=False)

        long_cands = argsorted[-num_candidates_for_choice_long:]
        long_choice = np.random.choice(
            long_cands, size=num_edges_long, replace=False)

        for i in np.concatenate([short_choice, long_choice]):
            edges[cur_point_idx].append(i)

    return dict(edges)   



def nsw(query_point: np.ndarray, all_documents: np.ndarray, 
        graph_edges: Dict[int, List[int]],
        search_k: int = 10, num_start_points: int = 5,
        dist_f: Callable = distance) -> np.ndarray:
        
        node_connections = len(list(graph_edges.values())[0])
        print(node_connections)
        
        if search_k>num_start_points*node_connections:
            num_start_points = num_start_points+(search_k-num_start_points*node_connections)//node_connections+1
                        
#         np.random.seed(1)
        
        start_points_ind = np.random.choice(len(all_documents), num_start_points)
        
        top_k_ind_total = list()
        
        for curr_point_ind in start_points_ind:
        
            dist_curr_point = np.inf
            min_dist_iter = -np.inf

            while min_dist_iter < dist_curr_point:
                
                dist_curr_point = dist_f(query_point, all_documents[curr_point_ind])
                dist_curr_point = np.float(dist_curr_point)
                
                top_k_ind_total.append([curr_point_ind])
                
                conn_points = graph_edges[curr_point_ind]
                dist = dist_f(query_point, all_documents[conn_points])
                min_dist_iter = np.min(dist)
                min_dist_iter_ind = np.argmin(dist)
                curr_point_ind = conn_points[min_dist_iter_ind]
                
                top_k_ind = np.argsort(dist, axis=0).reshape(-1,)[:search_k]
                top_k_ind = [conn_points[i] for i in top_k_ind]
                top_k_ind_total.append(top_k_ind)
                
#                 print(f'dist_curr_point {dist_curr_point}')
#                 print(f'min_dist_iter {min_dist_iter}')
#                 print('\n')

#             print('----next random----')
        
        top_k_ind_total = [i for sl in top_k_ind_total for i in sl]
        top_k_ind_total = np.unique(top_k_ind_total)
        
        dist_ = dist_f(query_point, all_documents[top_k_ind_total])
        
        top_k_ind = np.argsort(dist_, axis=0).reshape(-1,)[:search_k]
                        
        final = [top_k_ind_total[i] for i in top_k_ind]
        
        return final

    
# ideal   
def calc_d_and_upd(all_visited_points: OrderedDict, query_point: np.ndarray,
                   all_documents: np.ndarray, point_idx: int, dist_f: Callable
                   ) -> Tuple[float, bool]:
    if point_idx in all_visited_points:
        return all_visited_points[point_idx], True
    cur_dist = dist_f(
        query_point, all_documents[point_idx, :].reshape(1, -1))[0][0]
    all_visited_points[point_idx] = cur_dist
    return cur_dist, False


def nsw(query_point: np.ndarray, all_documents: np.ndarray, graph_edges: Dict,
        search_k: int = 10, num_start_points: int = 5,
        dist_f: Callable = distance) -> np.ndarray:
    all_visited_points = OrderedDict()
    num_started_points = 0
    # pbar = tqdm(total=num_start_points)
    while ((num_started_points < num_start_points) or (len(all_visited_points) < search_k)):
        # pbar.update(1)
        cur_point_idx = np.random.randint(0, all_documents.shape[0]-1)
        cur_dist, verdict = calc_d_and_upd(
            all_visited_points, query_point, all_documents, cur_point_idx, dist_f)
        if verdict:
            continue

        while True:
            min_dist = cur_dist
            choiced_cand = cur_point_idx

            cands_idxs = graph_edges[cur_point_idx]
            true_verdict_cands = set([cur_point_idx])
            for cand_idx in cands_idxs:
                tmp_d, verdict = calc_d_and_upd(
                    all_visited_points, query_point, all_documents, cand_idx, dist_f)
                if tmp_d < min_dist:
                    min_dist = tmp_d
                    choiced_cand = cand_idx
                if verdict:
                    true_verdict_cands.add(cand_idx)
            else:
                if choiced_cand in true_verdict_cands:
                    break
                cur_dist = min_dist
                cur_point_idx = choiced_cand
                continue
            break
        num_started_points += 1

    best_idxs = np.argsort(list(all_visited_points.values()))[:search_k]
    final_idx = np.array(list(all_visited_points.keys()))[best_idxs]
    return final_idx

In [287]:
# point = np.random.rand(1,128)
# all_documents = np.random.rand(50,128)

In [288]:
ind_dict = create_sw_graph(data=all_documents, dist_f=distance)
graph_edges = ind_dict

In [289]:
out = nsw(point, all_documents, graph_edges, dist_f=distance, search_k=100, num_start_points=1)
out

10


[41,
 1,
 19,
 11,
 33,
 36,
 0,
 4,
 37,
 9,
 8,
 38,
 32,
 13,
 39,
 21,
 2,
 42,
 22,
 3,
 40,
 28,
 5,
 30,
 48,
 26,
 15,
 6,
 45,
 27,
 7,
 17,
 14,
 12,
 43,
 31,
 34,
 10,
 24,
 49,
 46,
 25,
 20,
 16,
 23,
 29,
 44]

In [285]:
len(out)

50

In [778]:
ind_dict = create_sw_graph(data=data, dist_f=distance)

In [773]:
start_points_ind = np.random.choice(len(data), 5)
for sp in start_points_ind:
    print(sp)

23
36
27
6
37


In [547]:
len(start_points_ind)

5

In [512]:
np.argsort(dist, axis=0)[1:]

array([[26],
       [45],
       [ 7],
       [40],
       [36],
       [ 3],
       [27],
       [20],
       [ 4],
       [15],
       [43],
       [39],
       [29],
       [34],
       [ 5],
       [24],
       [42],
       [47],
       [23],
       [ 8],
       [ 6],
       [38],
       [28],
       [35],
       [32],
       [31],
       [ 9],
       [19],
       [10],
       [41],
       [17],
       [11],
       [18],
       [22],
       [12],
       [ 1],
       [37],
       [13],
       [16],
       [30],
       [25],
       [ 2],
       [21],
       [44],
       [48],
       [49],
       [33],
       [46],
       [14]])