In [157]:
from collections import OrderedDict, defaultdict
from typing import Callable, Tuple, Dict, List, OrderedDict

import numpy as np
from numpy.linalg import norm
from tqdm.auto import tqdm

In [168]:
def distance(pointA: np.ndarray, documents: np.ndarray) -> np.ndarray:
    """Euclidean (L2) distance between ``pointA`` and every row of ``documents``.

    Args:
        pointA: Query vector; it is broadcast against ``documents`` by the
            subtraction (e.g. shape ``(dim,)`` or ``(1, dim)``).
        documents: Array whose rows are compared with ``pointA``
            (presumably ``(n_documents, dim)``).

    Returns:
        The norms taken along axis 1, with that axis kept as size 1
        (``(n_documents, 1)`` for 2-D input).
    """
    offsets = pointA - documents
    return np.linalg.norm(offsets, axis=1, keepdims=True)

def _pick_edges(pool: np.ndarray, count: int) -> np.ndarray:
    """Draw up to ``count`` distinct entries of ``pool`` uniformly at random."""
    count = min(max(count, 0), pool.size)
    if count == 0:
        # Empty slice keeps the integer dtype so it can be concatenated safely.
        return pool[:0]
    return np.random.choice(pool, count, replace=False)


def create_sw_graph(
        data: np.ndarray,
        num_candidates_for_choice_long: int = 10,
        num_edges_long: int = 5,
        num_candidates_for_choice_short: int = 10,
        num_edges_short: int = 5,
        use_sampling: bool = False,
        sampling_share: float = 0.05,
        dist_f: Callable = distance
    ) -> Dict[int, np.ndarray]:
    """Build a small-world graph over the rows of ``data``.

    For every point the other points are ranked by ``dist_f``. "Long" edges are
    drawn at random from the farthest candidates and "short" edges from the
    nearest ones.

    Args:
        data: 2-D array, one point per row.
        num_candidates_for_choice_long: Size of the pool of farthest points
            from which long edges are drawn.
        num_edges_long: Number of long edges per vertex (capped by pool size).
        num_candidates_for_choice_short: Size of the pool of nearest points
            from which short edges are drawn.
        num_edges_short: Number of short edges per vertex (capped by pool size).
        use_sampling: If True, each point is compared only against a random
            subset of the other points instead of all of them.
        sampling_share: Fraction of ``len(data)`` used as the subset size when
            ``use_sampling`` is True (at least one point is always sampled).
        dist_f: Callable ``(point, documents) -> distances`` with one distance
            per row of ``documents``.

    Returns:
        Mapping ``row index -> 1-D integer array of neighbour row indices``
        (long edges first, then short edges, duplicates removed).
    """
    data = np.asarray(data)
    num_points = len(data)
    all_indices = np.arange(num_points)

    sw_graph = dict()
    for point_idx, point in enumerate(data):
        # Exclude the point itself by index rather than assuming it sorts
        # first: with duplicate points "drop the first" may keep a self-loop.
        candidates = all_indices[all_indices != point_idx]

        if use_sampling and candidates.size:
            sample_size = min(candidates.size,
                              max(1, int(sampling_share * num_points)))
            candidates = np.random.choice(candidates, sample_size, replace=False)

        dists = np.asarray(dist_f(point, data[candidates])).reshape(-1)
        ranked = candidates[np.argsort(dists)]  # nearest first

        # Clamp pool sizes; slicing from an explicit start avoids the
        # ``[-0:]`` pitfall that would select the whole array.
        k_long = min(max(num_candidates_for_choice_long, 0), ranked.size)
        k_short = min(max(num_candidates_for_choice_short, 0), ranked.size)
        long_pool = ranked[ranked.size - k_long:][::-1]
        short_pool = ranked[:k_short]

        long_edges = _pick_edges(long_pool, num_edges_long)
        short_edges = _pick_edges(short_pool, num_edges_short)

        edges = np.concatenate([long_edges, short_edges])
        # On small inputs the two pools overlap; keep the first occurrence of
        # each neighbour so no edge is listed twice.
        _, first_pos = np.unique(edges, return_index=True)
        sw_graph[point_idx] = edges[np.sort(first_pos)]
    return sw_graph

#     graph_dct: OrderedDict[str, int] = OrderedDict()
#     if use_sampling:
#         N = len(data)
#         sampled_size = int(sampling_share*N)
#         sampled_indx = np.random.choice(range(N), sampled_size)
#         data = np.take(data, sampled_indx, axis=0)
        
#     N, _ = data.shape
#     data_indxs = list(range(N))

#     for indx, point in tqdm(enumerate(data), total=N):
#         pntA = data[indx]
#         others_indx = np.array(list(set(data_indxs) - set({indx})), dtype=np.int32)
# #         print(others_indx)
#         dists_arr = dist_f(pntA, np.take(data, others_indx, axis=0))
#         dist2indx = np.concatenate((dists_arr, others_indx.reshape(-1,1)),axis=1)
#         sorted_indx = np.argsort(dist2indx, axis=0)
#         pos = dist2indx[sorted_indx, 1].reshape(-1).astype(int)
#         long_indx = np.random.choice(pos[-num_candidates_for_choice_long:], 
#                                      size=num_edges_long, 
#                                      replace=False
#                                     )
#         short_indx = np.random.choice(pos[:num_candidates_for_choice_short], 
#                                       size=num_edges_short,
#                                       replace=False
#                                      )        
#         graph_dct[indx] = list(long_indx)+list(short_indx)
    
#     return graph_dct

def nsw_(query_point: np.ndarray, all_documents: np.ndarray,
        graph_edges: Dict[int, List[int]],
        search_k: int = 10, num_start_points: int = 5,
        dist_f: Callable = distance) -> np.ndarray:
    """One-hop approximate nearest-neighbour lookup on a small-world graph.

    Picks random entry vertices, scores only their direct neighbours against
    the query, and returns the closest ones. No graph walk is performed.

    Args:
        query_point: Query vector, broadcastable against rows of
            ``all_documents``.
        all_documents: 2-D array of document vectors, indexed by vertex id.
        graph_edges: Adjacency mapping ``vertex id -> neighbour vertex ids``.
        search_k: Maximum number of indices to return.
        num_start_points: Number of distinct random entry vertices
            (capped by the number of vertices in the graph).
        dist_f: Callable ``(point, documents) -> distances`` with one distance
            per row of ``documents``.

    Returns:
        Integer array of at most ``search_k`` vertex ids, nearest first.
        Empty if nothing was scored.
    """
    vertices = list(graph_edges.keys())
    if not vertices or num_start_points <= 0:
        return np.array([], dtype=int)

    # Sample without replacement so every entry point is distinct (the same
    # policy ``nsw`` uses); sampling from the graph's own keys guarantees each
    # entry point has an adjacency list.
    num_starts = min(num_start_points, len(vertices))
    entry_points = np.random.choice(vertices, num_starts, replace=False)

    visited_vertex = dict()
    for entry in entry_points:
        neighbours = np.asarray(graph_edges[entry], dtype=int).reshape(-1)
        if neighbours.size == 0:
            continue
        # reshape(-1) instead of squeeze(): a single neighbour must still
        # yield a 1-D array, not a 0-d scalar that cannot be zipped.
        dists = np.asarray(dist_f(query_point, all_documents[neighbours])).reshape(-1)
        visited_vertex.update(zip(neighbours.tolist(), dists.tolist()))

    # Rank ids directly by their distance; no detour through a float array.
    ranked = sorted(visited_vertex, key=visited_vertex.get)
    return np.array(ranked[:search_k], dtype=int)

def nsw(query_point: np.ndarray, 
        all_documents: np.ndarray, 
        graph_edges: Dict[int, List[int]],
        search_k: int = 10, 
        num_start_points: int = 5,
        dist_f: Callable = distance,
        max_hops: int = 10) -> np.ndarray:
    """Greedy approximate nearest-neighbour search on a small-world graph.

    From each random entry vertex the search scores that vertex's neighbours
    against the query, then repeatedly jumps to the closest vertex seen so far
    (across all entry points) and scores its neighbours, for at most
    ``max_hops`` jumps.

    Args:
        query_point: Query vector, broadcastable against rows of
            ``all_documents``.
        all_documents: 2-D array of document vectors, indexed by vertex id.
        graph_edges: Adjacency mapping ``vertex id -> neighbour vertex ids``.
        search_k: Maximum number of indices to return.
        num_start_points: Number of distinct random entry vertices
            (capped by the number of vertices in the graph).
        dist_f: Callable ``(point, documents) -> distances`` with one distance
            per row of ``documents``.
        max_hops: Maximum number of jumps made from each entry vertex.

    Returns:
        Integer array of at most ``search_k`` vertex ids, nearest first.
        Empty if nothing was scored.
    """
    vertices = list(graph_edges.keys())
    if not vertices or num_start_points <= 0:
        return np.array([], dtype=int)

    num_starts = min(num_start_points, len(vertices))
    rnd_start = np.random.choice(vertices, num_starts, replace=False)

    visited_vertex = dict()   # vertex id -> distance to the query
    expanded = set()          # vertices whose neighbours were already scored
    for start in rnd_start:
        ref_pnt = int(start)
        for hop in range(max_hops + 1):
            if ref_pnt in expanded:
                # After the first step ref_pnt is always the best vertex seen
                # so far; if it was already expanded, re-scoring it cannot
                # change anything, so this walk has converged.
                # NOTE(review): once the walk has converged, later entry
                # points contribute little beyond their own neighbours.
                if hop > 0:
                    break
            else:
                expanded.add(ref_pnt)
                neighbours = np.asarray(graph_edges[ref_pnt], dtype=int).reshape(-1)
                if neighbours.size:
                    # reshape(-1) keeps a single distance 1-D (squeeze would
                    # produce a 0-d value that cannot be zipped).
                    out_dist = np.asarray(
                        dist_f(query_point, all_documents[neighbours])
                    ).reshape(-1)
                    visited_vertex.update(zip(neighbours.tolist(), out_dist.tolist()))
            if not visited_vertex:
                break
            # min() finds the closest visited vertex without a full sort.
            ref_pnt = min(visited_vertex, key=visited_vertex.get)

    # Rank once, after all walks are done.
    ranked = sorted(visited_vertex, key=visited_vertex.get)
    return np.array(ranked[:search_k], dtype=int)

# def nsw(query_point: np.ndarray, 
#         all_documents: np.ndarray, 
#         graph_edges: Dict[int, List[int]],
#         search_k: int = 10, 
#         num_start_points: int = 5,
#         dist_f: Callable = distance) -> np.ndarray:
    
#     all_nearest_pnts = OrderedDict()
#     nearest_pnts = OrderedDict()
#     N = list(graph_edges.keys())
#     numOfPassedPnt = 5
#     numOfMaxRandomization = 5
#     allfound = False
    
# #     np.random.seed(0)
#     rnd_ini_cnt = 0
    
#     while True:
#         rnd_start = np.random.choice(N, num_start_points, replace=False) 
#         rnd_ini_cnt += 1
#         for pnt in rnd_start:
#             nearest_pnts = OrderedDict()
#             ref_pnt = pnt
#             cnt = 0
#             while True:
#                 cnt+=1
#                 out_dist = distance(query_point, all_documents[graph_edges[ref_pnt]]).squeeze()
#                 sorted_indx = np.argsort(out_dist)
#                 dist_n, ind_n = out_dist[sorted_indx[0]], graph_edges[ref_pnt][sorted_indx[0]]
# #                 print(pnt, dist_n, ind_n, "##", ", ".join([str(k) for k, _ in all_nearest_pnts.items()]))
#                 if (len(nearest_pnts) == search_k) | (cnt >= numOfPassedPnt):
#                     tup = [(k,v) for k,v in nearest_pnts.items()]
#                     local_nearest = sorted(tup, key=lambda x: x[1])[0]
#                     all_nearest_pnts.update({local_nearest[0]:local_nearest[1]})
#                     break
#                 elif len(nearest_pnts)==0:
#                     nearest_pnts[ind_n] = dist_n
#                 elif (dist_n < np.array([v for _,v in nearest_pnts.items()]).min()):
#                     nearest_pnts[ind_n] = dist_n
#                 elif rnd_ini_cnt >= numOfMaxRandomization:
#                     res = list(zip(graph_edges[ref_pnt], out_dist))
#                     _dct = dict(res)
#                     all_nearest_pnts.update(_dct)
#                     break
                    
#                 ref_pnt = ind_n
#             if len(all_nearest_pnts) >= search_k:
#                 allfound = True
#                 break
#         if allfound: break
            
#         if rnd_ini_cnt >= numOfMaxRandomization:
#             pass
# #             pntfromlastoop = sorted([(k,v) for k, _ in all_nearest_pnts.items()], lambda x: x[1])[-1][0]
           
#     return np.array([k for k, _ in all_nearest_pnts.items()])[:search_k]


In [159]:
# Synthetic benchmark data. The global NumPy RNG is seeded so that the data
# (and the random graph / search steps that follow) are reproducible on
# Restart Kernel -> Run All.
SEED = 0
np.random.seed(SEED)

N = 200      # number of documents
dim = 100    # dimensionality of each vector
documents = np.random.randn(N, dim)
query = np.random.randn(1, dim)

In [153]:
# Query as a column vector minus the transposed document matrix: broadcasting
# gives one column of per-coordinate differences for each document.
diff = query.reshape(-1, 1) - documents.T
diff.shape

(100, 200)

In [154]:
# Exact distances from the query to every document, flattened for display.
# NOTE(review): an earlier draft compared this against an alternative
# implementation (`distance_alt`) with np.allclose; that helper is not defined
# in this notebook.
distances = distance(pointA=query, documents=documents)
np.squeeze(distances)

array([12.94096826, 12.34656274, 13.3176857 , 12.78798224, 12.84057816,
       13.76788397, 13.51852287, 12.37913119, 13.56270957, 13.35401403,
       14.47644703, 12.67214219, 12.78114363, 14.75969695, 12.92913957,
       13.78725913, 14.26709887, 13.25608621, 12.58773002, 13.38394433,
       15.52113904, 12.78835231, 13.88408675, 13.13441247, 13.70124271,
       13.35543822, 14.53143574, 14.54222885, 14.80041709, 14.70106368,
       13.18178027, 12.10342057, 13.60313734, 14.57790904, 13.18813521,
       13.81732067, 14.00934993, 13.30250725, 13.79863683, 14.68264012,
       12.76603164, 14.17642106, 14.68177373, 13.79574296, 13.95019656,
       12.9124668 , 12.95719454, 14.00435476, 15.61099068, 14.17653652,
       14.42813739, 12.77702516, 14.63529312, 13.84560216, 15.51758456,
       13.31238796, 13.71093988, 14.25359205, 14.02524398, 13.7125792 ,
       14.61573458, 14.36371592, 14.94988882, 15.1389256 , 12.82714306,
       13.50456823, 13.63751708, 14.08275188, 13.22887738, 14.41

In [160]:
# Build the small-world graph over all documents, without sampling.
sw_graph = create_sw_graph(data=documents, use_sampling=False, dist_f=distance)

In [33]:
# Run the graph search 500 times on the same query; only the result of the
# last run is kept in `out`.
for i in tqdm(range(500)):
    out = nsw(
        query_point=query,
        all_documents=documents,
        graph_edges=sw_graph,
        search_k=50,
        num_start_points=5,
    )

  0%|          | 0/500 [00:00<?, ?it/s]

In [171]:
# Multi-hop search: up to 13 document indices, nearest first.
nsw(query_point=query, all_documents=documents, graph_edges=sw_graph,
    search_k=13, num_start_points=5)

array([162, 125,   9, 120,   2, 169, 152,  44, 135,  56, 164, 176, 190])

In [172]:
# One-hop variant for comparison with the call above: same query and graph.
nsw_(query_point=query, all_documents=documents, graph_edges=sw_graph,
     search_k=13, num_start_points=5)

array([ 30,  86, 125,   2, 169, 164, 176, 131, 186,  59,  77, 114,  62])